author    Gard Spreemann <gspr@nonempty.org>  2022-04-27 11:49:23 +0200
committer Gard Spreemann <gspr@nonempty.org>  2022-04-27 11:49:23 +0200
commit    35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree      6bc637624004713808d3097b95acdccbb9608e52 /ot
parent    c4753bd3f74139af8380127b66b484bc09b50661 (diff)
parent    eccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'ot')
-rw-r--r--  ot/__init__.py        14
-rw-r--r--  ot/backend.py        306
-rw-r--r--  ot/bregman.py         17
-rw-r--r--  ot/da.py             382
-rw-r--r--  ot/dr.py              44
-rw-r--r--  ot/factored.py       145
-rw-r--r--  ot/gpu/__init__.py    50
-rw-r--r--  ot/gpu/bregman.py    196
-rw-r--r--  ot/gpu/da.py         144
-rw-r--r--  ot/gpu/utils.py      101
-rw-r--r--  ot/gromov.py        1109
-rw-r--r--  ot/lp/__init__.py    123
-rw-r--r--  ot/lp/cvx.py           1
-rw-r--r--  ot/optim.py           11
-rwxr-xr-x  ot/partial.py         84
-rw-r--r--  ot/plot.py             7
-rw-r--r--  ot/regpath.py        545
-rw-r--r--  ot/stochastic.py     242
-rw-r--r--  ot/unbalanced.py     525
-rw-r--r--  ot/utils.py           36
-rw-r--r--  ot/weak.py           124
21 files changed, 3053 insertions, 1153 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index f55819d..86ed94e 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,5 +1,4 @@
"""
-
.. warning::
The list of automatically imported sub-modules is as follows:
:py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
@@ -7,13 +6,10 @@
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
:py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`
, :py:mod:`ot.unbalanced`.
-
The following sub-modules are not imported due to additional dependencies:
-
- :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
- :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU.
- :any:`ot.plot` : depends on :code:`matplotlib`
-
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -36,6 +32,8 @@ from . import unbalanced
from . import partial
from . import backend
from . import regpath
+from . import weak
+from . import factored
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -46,11 +44,14 @@ from .da import sinkhorn_lpl1_mm
from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
+from .weak import weak_optimal_transport
+from .factored import factored_optimal_transport
+
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "0.8.1.0"
+__version__ = "0.8.2"
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
@@ -59,5 +60,6 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
- 'max_sliced_wasserstein_distance',
+ 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
+ 'factored_optimal_transport',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
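For reference, a quick sketch of what the re-export changes above mean for user code after this merge; nothing here goes beyond the names the hunk adds:

    import ot

    assert ot.__version__ == "0.8.2"

    # both new solvers are reachable from the package root as well as their modules
    from ot import weak_optimal_transport, factored_optimal_transport
    from ot.weak import weak_optimal_transport          # equivalent import path
    from ot.factored import factored_optimal_transport  # equivalent import path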
diff --git a/ot/backend.py b/ot/backend.py
index 58b652b..361ffba 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -87,7 +87,9 @@ Performance
# License: MIT License
import numpy as np
-import scipy.special as scipy
+import scipy
+import scipy.linalg
+import scipy.special as special
from scipy.sparse import issparse, coo_matrix, csr_matrix
import warnings
import time
@@ -102,7 +104,7 @@ except ImportError:
try:
import jax
import jax.numpy as jnp
- import jax.scipy.special as jscipy
+ import jax.scipy.special as jspecial
from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
@@ -202,13 +204,29 @@ class Backend():
def __str__(self):
return self.__name__
- # convert to numpy
- def to_numpy(self, a):
+ # convert batch of tensors to numpy
+ def to_numpy(self, *arrays):
+ """Returns the numpy version of tensors"""
+ if len(arrays) == 1:
+ return self._to_numpy(arrays[0])
+ else:
+ return [self._to_numpy(array) for array in arrays]
+
+ # convert a tensor to numpy
+ def _to_numpy(self, a):
"""Returns the numpy version of a tensor"""
raise NotImplementedError()
- # convert from numpy
- def from_numpy(self, a, type_as=None):
+ # convert batch of arrays from numpy
+ def from_numpy(self, *arrays, type_as=None):
+ """Creates tensors cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
+ if len(arrays) == 1:
+ return self._from_numpy(arrays[0], type_as=type_as)
+ else:
+ return [self._from_numpy(array, type_as=type_as) for array in arrays]
+
+ # convert an array from numpy
+ def _from_numpy(self, a, type_as=None):
"""Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
raise NotImplementedError()
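With this hunk, to_numpy and from_numpy become thin batch wrappers around the per-tensor _to_numpy/_from_numpy, so several arrays can be converted in one call. A small sketch of the resulting calling convention, assuming a backend obtained through ot.backend.get_backend (here the NumPy one):

    import numpy as np
    from ot.backend import get_backend

    a = np.ones(3) / 3
    b = np.ones(4) / 4
    nx = get_backend(a, b)

    # one argument: a single array back; several arguments: a list of arrays
    a_np = nx.to_numpy(a)
    a_np, b_np = nx.to_numpy(a, b)

    # same batching for from_numpy, with optional dtype/device matching via type_as
    a2, b2 = nx.from_numpy(a_np, b_np, type_as=a)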
@@ -536,6 +554,16 @@ class Backend():
"""
raise NotImplementedError()
+ def argmin(self, a, axis=None):
+ r"""
+ Returns the indices of the minimum values of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.argmin`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html
+ """
+ raise NotImplementedError()
+
def mean(self, a, axis=None):
r"""
Computes the arithmetic mean of a tensor along given dimensions.
@@ -786,6 +814,72 @@ class Backend():
"""
raise NotImplementedError()
+ def solve(self, a, b):
+ r"""
+ Solves a linear matrix equation, or system of linear scalar equations.
+
+ This function follows the api from :any:`numpy.linalg.solve`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html
+ """
+ raise NotImplementedError()
+
+ def trace(self, a):
+ r"""
+ Returns the sum along diagonals of the array.
+
+ This function follows the api from :any:`numpy.trace`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html
+ """
+ raise NotImplementedError()
+
+ def inv(self, a):
+ r"""
+ Computes the inverse of a matrix.
+
+ This function follows the api from :any:`scipy.linalg.inv`.
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html
+ """
+ raise NotImplementedError()
+
+ def sqrtm(self, a):
+ r"""
+ Computes the matrix square root. Requires input to be definite positive.
+
+ This function follows the api from :any:`scipy.linalg.sqrtm`.
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html
+ """
+ raise NotImplementedError()
+
+ def isfinite(self, a):
+ r"""
+ Tests element-wise for finiteness (not infinity and not Not a Number).
+
+ This function follows the api from :any:`numpy.isfinite`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html
+ """
+ raise NotImplementedError()
+
+ def array_equal(self, a, b):
+ r"""
+ True if two arrays have the same shape and elements, False otherwise.
+
+ This function follows the api from :any:`numpy.array_equal`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html
+ """
+ raise NotImplementedError()
+
+ def is_floating_point(self, a):
+ r"""
+ Returns whether or not the input consists of floats
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -802,10 +896,10 @@ class NumpyBackend(Backend):
rng_ = np.random.RandomState()
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
if type_as is None:
return a
elif isinstance(a, float):
@@ -936,6 +1030,9 @@ class NumpyBackend(Backend):
def argmax(self, a, axis=None):
return np.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return np.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return np.mean(a, axis=axis)
@@ -955,7 +1052,7 @@ class NumpyBackend(Backend):
return np.unique(a)
def logsumexp(self, a, axis=None):
- return scipy.logsumexp(a, axis=axis)
+ return special.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return np.stack(arrays, axis)
@@ -1004,8 +1101,11 @@ class NumpyBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return np.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return np.where(condition)
+ else:
+ return np.where(condition, x, y)
def copy(self, a):
return a.copy()
@@ -1046,6 +1146,27 @@ class NumpyBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+ def solve(self, a, b):
+ return np.linalg.solve(a, b)
+
+ def trace(self, a):
+ return np.trace(a)
+
+ def inv(self, a):
+ return scipy.linalg.inv(a)
+
+ def sqrtm(self, a):
+ return scipy.linalg.sqrtm(a)
+
+ def isfinite(self, a):
+ return np.isfinite(a)
+
+ def array_equal(self, a, b):
+ return np.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class JaxBackend(Backend):
"""
@@ -1075,13 +1196,15 @@ class JaxBackend(Backend):
jax.device_put(jnp.array(1, dtype=jnp.float64), d)
]
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return np.array(a)
def _change_device(self, a, type_as):
return jax.device_put(a, type_as.device_buffer.device())
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return jnp.array(a)
else:
@@ -1216,6 +1339,9 @@ class JaxBackend(Backend):
def argmax(self, a, axis=None):
return jnp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return jnp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return jnp.mean(a, axis=axis)
@@ -1235,7 +1361,7 @@ class JaxBackend(Backend):
return jnp.unique(a)
def logsumexp(self, a, axis=None):
- return jscipy.logsumexp(a, axis=axis)
+ return jspecial.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return jnp.stack(arrays, axis)
@@ -1293,8 +1419,11 @@ class JaxBackend(Backend):
# Currently, JAX does not support sparse matrices
return a
- def where(self, condition, x, y):
- return jnp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return jnp.where(condition)
+ else:
+ return jnp.where(condition, x, y)
def copy(self, a):
# No need to copy, JAX arrays are immutable
@@ -1339,6 +1468,28 @@ class JaxBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+ def solve(self, a, b):
+ return jnp.linalg.solve(a, b)
+
+ def trace(self, a):
+ return jnp.trace(a)
+
+ def inv(self, a):
+ return jnp.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = jnp.linalg.eigh(a)
+ return (V * jnp.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return jnp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return jnp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class TorchBackend(Backend):
"""
@@ -1384,10 +1535,10 @@ class TorchBackend(Backend):
self.ValFunction = ValFunction
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a.cpu().detach().numpy()
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
if isinstance(a, float):
a = np.array(a)
if type_as is None:
@@ -1397,7 +1548,7 @@ class TorchBackend(Backend):
def set_gradients(self, val, inputs, grads):
- Func = self.ValFunction()
+ Func = self.ValFunction
res = Func.apply(val, grads, *inputs)
@@ -1564,6 +1715,9 @@ class TorchBackend(Backend):
def argmax(self, a, axis=None):
return torch.argmax(a, dim=axis)
+ def argmin(self, a, axis=None):
+ return torch.argmin(a, dim=axis)
+
def mean(self, a, axis=None):
if axis is not None:
return torch.mean(a, dim=axis)
@@ -1580,8 +1734,11 @@ class TorchBackend(Backend):
return torch.linspace(start, stop, num, dtype=torch.float64)
def meshgrid(self, a, b):
- X, Y = torch.meshgrid(a, b)
- return X.T, Y.T
+ try:
+ return torch.meshgrid(a, b, indexing="xy")
+ except TypeError:
+ X, Y = torch.meshgrid(a, b)
+ return X.T, Y.T
def diag(self, a, k=0):
return torch.diag(a, diagonal=k)
@@ -1659,8 +1816,11 @@ class TorchBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return torch.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return torch.where(condition)
+ else:
+ return torch.where(condition, x, y)
def copy(self, a):
return torch.clone(a)
@@ -1718,6 +1878,28 @@ class TorchBackend(Backend):
torch.cuda.empty_cache()
return results
+ def solve(self, a, b):
+ return torch.linalg.solve(a, b)
+
+ def trace(self, a):
+ return torch.trace(a)
+
+ def inv(self, a):
+ return torch.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = torch.linalg.eigh(a)
+ return (V * torch.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return torch.isfinite(a)
+
+ def array_equal(self, a, b):
+ return torch.equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.is_floating_point
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -1741,10 +1923,12 @@ class CupyBackend(Backend): # pragma: no cover
cp.array(1, dtype=cp.float64)
]
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return cp.asnumpy(a)
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return cp.asarray(a)
else:
@@ -1884,6 +2068,9 @@ class CupyBackend(Backend): # pragma: no cover
def argmax(self, a, axis=None):
return cp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return cp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return cp.mean(a, axis=axis)
@@ -1982,8 +2169,11 @@ class CupyBackend(Backend): # pragma: no cover
else:
return a
- def where(self, condition, x, y):
- return cp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return cp.where(condition)
+ else:
+ return cp.where(condition, x, y)
def copy(self, a):
return a.copy()
@@ -2035,6 +2225,28 @@ class CupyBackend(Backend): # pragma: no cover
pinned_mempool.free_all_blocks()
return results
+ def solve(self, a, b):
+ return cp.linalg.solve(a, b)
+
+ def trace(self, a):
+ return cp.trace(a)
+
+ def inv(self, a):
+ return cp.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = cp.linalg.eigh(a)
+ return (V * self.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return cp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return cp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class TensorflowBackend(Backend):
@@ -2060,13 +2272,16 @@ class TensorflowBackend(Backend):
"To use TensorflowBackend, you need to activate the tensorflow "
"numpy API. You can activate it by running: \n"
"from tensorflow.python.ops.numpy_ops import np_config\n"
- "np_config.enable_numpy_behavior()"
+ "np_config.enable_numpy_behavior()",
+ stacklevel=2
)
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a.numpy()
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if not isinstance(a, self.__type__):
if type_as is None:
return tf.convert_to_tensor(a)
@@ -2208,6 +2423,9 @@ class TensorflowBackend(Backend):
def argmax(self, a, axis=None):
return tnp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return tnp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return tnp.mean(a, axis=axis)
@@ -2309,8 +2527,11 @@ class TensorflowBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return tnp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return tnp.where(condition)
+ else:
+ return tnp.where(condition, x, y)
def copy(self, a):
return tf.identity(a)
@@ -2364,3 +2585,24 @@ class TensorflowBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+
+ def solve(self, a, b):
+ return tf.linalg.solve(a, b)
+
+ def trace(self, a):
+ return tf.linalg.trace(a)
+
+ def inv(self, a):
+ return tf.linalg.inv(a)
+
+ def sqrtm(self, a):
+ return tf.linalg.sqrtm(a)
+
+ def isfinite(self, a):
+ return tnp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return tnp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.is_floating
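Taken together, the methods added to every backend in this file (argmin, single-argument where, solve, trace, inv, sqrtm, isfinite, array_equal, is_floating_point) are what lets the rewritten ot.da below stay backend-agnostic. A short sketch of the new calls on the NumPy backend; the matrices are illustrative only:

    import numpy as np
    from ot.backend import get_backend

    A = np.array([[2.0, 0.0], [0.0, 3.0]])
    b = np.array([1.0, 1.0])
    nx = get_backend(A, b)

    x = nx.solve(A, b)        # linear solve, following numpy.linalg.solve
    S = nx.sqrtm(A)           # matrix square root of an SPD matrix
    t = nx.trace(nx.inv(A))   # trace of the inverse
    i = nx.argmin(b)          # index of the minimum
    idx = nx.where(b > 0)     # single-argument where now returns indices
    ok = nx.array_equal(A, A) and nx.isfinite(t) and nx.is_floating_point(b)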
diff --git a/ot/bregman.py b/ot/bregman.py
index fc20175..c06af2f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -2525,8 +2525,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
# geometric interpolation
delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
- K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0)
-
+ K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0
err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
@@ -2656,16 +2655,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
classes = nx.unique(Ys[d])
# build the corresponding D_1 and D_2 matrices
- Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
- Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
+ Dtmp1 = np.zeros((nbclasses, nsk))
+ Dtmp2 = np.zeros((nbclasses, nsk))
for c in classes:
- nbelemperclass = nx.sum(Ys[d] == c)
+ nbelemperclass = float(nx.sum(Ys[d] == c))
if nbelemperclass != 0:
- Dtmp1[int(c), Ys[d] == c] = 1.
- Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
- D1.append(Dtmp1)
- D2.append(Dtmp2)
+ Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1.
+ Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. / (nbelemperclass)
+ D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0]))
+ D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0]))
# build the cost matrix and the Gibbs kernel
Mtmp = dist(Xs[d], Xt, metric=metric)
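The jcpot_barycenter change above builds the per-class indicator matrices D1 and D2 in plain NumPy, where boolean-mask assignment is cheap, and only then moves them onto the compute backend with nx.from_numpy. A toy sketch of what those matrices look like for a made-up label vector of one source domain:

    import numpy as np

    ys = np.array([0, 1, 1, 0, 1])        # illustrative labels, not from the diff
    nbclasses, nsk = 2, len(ys)

    D1 = np.zeros((nbclasses, nsk))
    D2 = np.zeros((nbclasses, nsk))
    for c in np.unique(ys):
        mask = (ys == c)
        D1[int(c), mask] = 1.0                      # class-membership indicator
        D2[int(c), mask] = 1.0 / float(mask.sum())  # row-normalised indicator
    # each row of D2 sums to 1; D1 row sums give the per-class sample counts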
diff --git a/ot/da.py b/ot/da.py
index 841f31a..0b9737e 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -12,12 +12,12 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
-import scipy.linalg as linalg
+from .backend import get_backend
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import check_params, BaseEstimator
+from .utils import list_to_array, check_params, BaseEstimator
from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
@@ -60,13 +60,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Parameters
----------
- a : np.ndarray (ns,)
+ a : array-like (ns,)
samples weights in the source domain
- labels_a : np.ndarray (ns,)
+ labels_a : array-like (ns,)
labels of samples in the source domain
- b : np.ndarray (nt,)
+ b : array-like (nt,)
samples weights in the target domain
- M : np.ndarray (ns,nt)
+ M : array-like (ns,nt)
loss matrix
reg : float
Regularization term for entropic regularization >0
@@ -86,7 +86,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -111,26 +111,28 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
ot.optim.cg : General regularized OT
"""
+ a, labels_a, b, M = list_to_array(a, labels_a, b, M)
+ nx = get_backend(a, labels_a, b, M)
+
p = 0.5
epsilon = 1e-3
indices_labels = []
- classes = np.unique(labels_a)
+ classes = nx.unique(labels_a)
for c in classes:
- idxc, = np.where(labels_a == c)
+ idxc, = nx.where(labels_a == c)
indices_labels.append(idxc)
- W = np.zeros(M.shape)
-
+ W = nx.zeros(M.shape, type_as=M)
for cpt in range(numItermax):
Mreg = M + eta * W
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really
# separated
- W = np.ones(M.shape)
+ W = nx.ones(M.shape, type_as=M)
for (i, c) in enumerate(classes):
- majs = np.sum(transp[indices_labels[i]], axis=0)
+ majs = nx.sum(transp[indices_labels[i]], axis=0)
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
@@ -174,13 +176,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Parameters
----------
- a : np.ndarray (ns,)
+ a : array-like (ns,)
samples weights in the source domain
- labels_a : np.ndarray (ns,)
+ labels_a : array-like (ns,)
labels of samples in the source domain
- b : np.ndarray (nt,)
+ b : array-like (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : array-like (ns,nt)
loss matrix
reg : float
Regularization term for entropic regularization >0
@@ -200,7 +202,7 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -222,22 +224,25 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
ot.optim.gcg : Generalized conditional gradient for OT problems
"""
- lstlab = np.unique(labels_a)
+ a, labels_a, b, M = list_to_array(a, labels_a, b, M)
+ nx = get_backend(a, labels_a, b, M)
+
+ lstlab = nx.unique(labels_a)
def f(G):
res = 0
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
- res += np.linalg.norm(temp)
+ res += nx.norm(temp)
return res
def df(G):
- W = np.zeros(G.shape)
+ W = nx.zeros(G.shape, type_as=G)
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
- n = np.linalg.norm(temp)
+ n = nx.norm(temp)
if n:
W[labels_a == lab, i] = temp / n
return W
@@ -289,9 +294,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
Parameters
----------
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
mu : float,optional
Weight for the linear OT loss (>0)
@@ -315,9 +320,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
- L : (d, d) ndarray
+ L : (d, d) array-like
Linear mapping matrix ((:math:`d+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
@@ -336,13 +341,15 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
ot.optim.cg : General regularized OT
"""
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1]
if bias:
- xs1 = np.hstack((xs, np.ones((ns, 1))))
- xstxs = xs1.T.dot(xs1)
- Id = np.eye(d + 1)
+ xs1 = nx.concatenate((xs, nx.ones((ns, 1), type_as=xs)), axis=1)
+ xstxs = nx.dot(xs1.T, xs1)
+ Id = nx.eye(d + 1, type_as=xs)
Id[-1] = 0
I0 = Id[:, :-1]
@@ -350,8 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
return x[:-1, :]
else:
xs1 = xs
- xstxs = xs1.T.dot(xs1)
- Id = np.eye(d)
+ xstxs = nx.dot(xs1.T, xs1)
+ Id = nx.eye(d, type_as=xs)
I0 = Id
def sel(x):
@@ -360,7 +367,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
if log:
log = {'err': []}
- a, b = unif(ns), unif(nt)
+ a = unif(ns, type_as=xs)
+ b = unif(nt, type_as=xt)
M = dist(xs, xt) * ns
G = emd(a, b, M)
@@ -368,23 +376,26 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
def loss(L, G):
"""Compute full loss"""
- return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \
- np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2)
+ return (
+ nx.sum((nx.dot(xs1, L) - ns * nx.dot(G, xt)) ** 2)
+ + mu * nx.sum(G * M)
+ + eta * nx.sum(sel(L - I0) ** 2)
+ )
def solve_L(G):
""" solve L problem with fixed G (least square)"""
- xst = ns * G.dot(xt)
- return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0)
+ xst = ns * nx.dot(G, xt)
+ return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0)
def solve_G(L, G0):
"""Update G with CG algorithm"""
- xsi = xs1.dot(L)
+ xsi = nx.dot(xs1, L)
def f(G):
- return np.sum((xsi - ns * G.dot(xt)) ** 2)
+ return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2)
def df(G):
- return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+ return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T)
G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
numItermax=numInnerItermax, stopThr=stopInnerThr)
@@ -481,9 +492,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
Parameters
----------
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
mu : float,optional
Weight for the linear OT loss (>0)
@@ -513,9 +524,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
- L : (ns, d) ndarray
+ L : (ns, d) array-like
Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
@@ -534,15 +545,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
ot.optim.cg : General regularized OT
"""
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
ns, nt = xs.shape[0], xt.shape[0]
K = kernel(xs, xs, method=kerneltype, sigma=sigma)
if bias:
- K1 = np.hstack((K, np.ones((ns, 1))))
- Id = np.eye(ns + 1)
+ K1 = nx.concatenate((K, nx.ones((ns, 1), type_as=xs)), axis=1)
+ Id = nx.eye(ns + 1, type_as=xs)
Id[-1] = 0
- Kp = np.eye(ns + 1)
+ Kp = nx.eye(ns + 1, type_as=xs)
Kp[:ns, :ns] = K
# ls regu
@@ -550,12 +563,12 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
# Kreg=I
# RKHS regul
- K0 = K1.T.dot(K1) + eta * Kp
+ K0 = nx.dot(K1.T, K1) + eta * Kp
Kreg = Kp
else:
K1 = K
- Id = np.eye(ns)
+ Id = nx.eye(ns, type_as=xs)
# ls regul
# K0 = K1.T.dot(K1)+eta*I
@@ -568,7 +581,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
if log:
log = {'err': []}
- a, b = unif(ns), unif(nt)
+ a = unif(ns, type_as=xs)
+ b = unif(nt, type_as=xt)
M = dist(xs, xt) * ns
G = emd(a, b, M)
@@ -576,28 +590,31 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def loss(L, G):
"""Compute full loss"""
- return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \
- np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+ return (
+ nx.sum((nx.dot(K1, L) - ns * nx.dot(G, xt)) ** 2)
+ + mu * nx.sum(G * M)
+ + eta * nx.trace(dots(L.T, Kreg, L))
+ )
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
- xst = ns * G.dot(xt)
- return np.linalg.solve(K0, xst)
+ xst = ns * nx.dot(G, xt)
+ return nx.solve(K0, xst)
def solve_L_bias(G):
""" solve L problem with fixed G (least square)"""
- xst = ns * G.dot(xt)
- return np.linalg.solve(K0, K1.T.dot(xst))
+ xst = ns * nx.dot(G, xt)
+ return nx.solve(K0, nx.dot(K1.T, xst))
def solve_G(L, G0):
"""Update G with CG algorithm"""
- xsi = K1.dot(L)
+ xsi = nx.dot(K1, L)
def f(G):
- return np.sum((xsi - ns * G.dot(xt)) ** 2)
+ return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2)
def df(G):
- return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+ return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T)
G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
numItermax=numInnerItermax, stopThr=stopInnerThr)
@@ -681,15 +698,15 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
Parameters
----------
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
reg : float,optional
regularization added to the diagonals of covariances (>0)
- ws : np.ndarray (ns,1), optional
+ ws : array-like (ns,1), optional
weights for the source samples
- wt : np.ndarray (ns,1), optional
+ wt : array-like (ns,1), optional
weights for the target samples
bias: boolean, optional
estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
@@ -699,9 +716,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
Returns
-------
- A : (d, d) ndarray
+ A : (d, d) array-like
Linear operator
- b : (1, d) ndarray
+ b : (1, d) array-like
bias
log : dict
log dictionary return only if log==True in parameters
@@ -719,36 +736,38 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
"""
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
d = xs.shape[1]
if bias:
- mxs = xs.mean(0, keepdims=True)
- mxt = xt.mean(0, keepdims=True)
+ mxs = nx.mean(xs, axis=0)[None, :]
+ mxt = nx.mean(xt, axis=0)[None, :]
xs = xs - mxs
xt = xt - mxt
else:
- mxs = np.zeros((1, d))
- mxt = np.zeros((1, d))
+ mxs = nx.zeros((1, d), type_as=xs)
+ mxt = nx.zeros((1, d), type_as=xs)
if ws is None:
- ws = np.ones((xs.shape[0], 1)) / xs.shape[0]
+ ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
if wt is None:
- wt = np.ones((xt.shape[0], 1)) / xt.shape[0]
+ wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
- Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d)
- Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d)
+ Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
+ Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
- Cs12 = linalg.sqrtm(Cs)
- Cs_12 = linalg.inv(Cs12)
+ Cs12 = nx.sqrtm(Cs)
+ Cs_12 = nx.inv(Cs12)
- M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+ M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
- A = Cs_12.dot(M0.dot(Cs_12))
+ A = dots(Cs_12, M0, Cs_12)
- b = mxt - mxs.dot(A)
+ b = mxt - nx.dot(mxs, A)
if log:
log = {}
@@ -798,15 +817,15 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
Parameters
----------
- a : np.ndarray (ns,)
+ a : array-like (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : array-like (nt,)
samples weights in the target domain
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : array-like (ns,nt)
loss matrix
sim : string, optional
Type of similarity ('knn' or 'gauss') used to construct the Laplacian.
@@ -834,7 +853,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -862,9 +881,12 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
raise ValueError(
'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__))
+ a, b, xs, xt, M = list_to_array(a, b, xs, xt, M)
+ nx = get_backend(a, b, xs, xt, M)
+
if sim == 'gauss':
if sim_param is None:
- sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
+ sim_param = 1 / (2 * (nx.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
sS = kernel(xs, xs, method=sim, sigma=sim_param)
sT = kernel(xt, xt, method=sim, sigma=sim_param)
@@ -874,9 +896,13 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
from sklearn.neighbors import kneighbors_graph
- sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray()
+ sS = nx.from_numpy(kneighbors_graph(
+ X=nx.to_numpy(xs), n_neighbors=int(sim_param)
+ ).toarray(), type_as=xs)
sS = (sS + sS.T) / 2
- sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray()
+ sT = nx.from_numpy(kneighbors_graph(
+ X=nx.to_numpy(xt), n_neighbors=int(sim_param)
+ ).toarray(), type_as=xt)
sT = (sT + sT.T) / 2
else:
raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim))
@@ -885,12 +911,14 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
lT = laplacian(sT)
def f(G):
- return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \
- + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs)))))
+ return (
+ alpha * nx.trace(dots(xt.T, G.T, lS, G, xt))
+ + (1 - alpha) * nx.trace(dots(xs.T, G, lT, G.T, xs))
+ )
ls2 = lS + lS.T
lt2 = lT + lT.T
- xt2 = np.dot(xt, xt.T)
+ xt2 = nx.dot(xt, xt.T)
if reg == 'disp':
Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T)
@@ -898,8 +926,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
M = M + Cs + Ct
def df(G):
- return alpha * np.dot(ls2, np.dot(G, xt2))\
- + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2)))
+ return (
+ alpha * dots(ls2, G, xt2)
+ + (1 - alpha) * dots(xs, xs.T, G, lt2)
+ )
return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax,
stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log)
@@ -919,7 +949,7 @@ def distribution_estimation_uniform(X):
The uniform distribution estimated from :math:`\mathbf{X}`
"""
- return unif(X.shape[0])
+ return unif(X.shape[0], type_as=X)
class BaseTransport(BaseEstimator):
@@ -973,6 +1003,7 @@ class BaseTransport(BaseEstimator):
self : object
Returns self.
"""
+ nx = self._get_backend(Xs, ys, Xt, yt)
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt):
@@ -984,14 +1015,14 @@ class BaseTransport(BaseEstimator):
if (ys is not None) and (yt is not None):
if self.limit_max != np.infty:
- self.limit_max = self.limit_max * np.max(self.cost_)
+ self.limit_max = self.limit_max * nx.max(self.cost_)
# assumes labeled source samples occupy the first rows
# and labeled target samples occupy the first columns
- classes = [c for c in np.unique(ys) if c != -1]
+ classes = [c for c in nx.unique(ys) if c != -1]
for c in classes:
- idx_s = np.where((ys != c) & (ys != -1))
- idx_t = np.where(yt == c)
+ idx_s = nx.where((ys != c) & (ys != -1))
+ idx_t = nx.where(yt == c)
# all the coefficients corresponding to a source sample
# and a target sample :
@@ -1062,23 +1093,24 @@ class BaseTransport(BaseEstimator):
transp_Xs : array-like, shape (n_source_samples, n_features)
The transport source samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- if np.array_equal(self.xs_, Xs):
+ if nx.array_equal(self.xs_, Xs):
# perform standard barycentric mapping
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute transported samples
- transp_Xs = np.dot(transp, self.xt_)
+ transp_Xs = nx.dot(transp, self.xt_)
else:
# perform out of sample mapping
- indices = np.arange(Xs.shape[0])
+ indices = nx.arange(Xs.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]
@@ -1087,20 +1119,20 @@ class BaseTransport(BaseEstimator):
for bi in batch_ind:
# get the nearest neighbor in the source domain
D0 = dist(Xs[bi], self.xs_)
- idx = np.argmin(D0, axis=1)
+ idx = nx.argmin(D0, axis=1)
# transport the source samples
- transp = self.coupling_ / np.sum(
- self.coupling_, 1)[:, None]
- transp[~ np.isfinite(transp)] = 0
- transp_Xs_ = np.dot(transp, self.xt_)
+ transp = self.coupling_ / nx.sum(
+ self.coupling_, axis=1)[:, None]
+ transp[~ nx.isfinite(transp)] = 0
+ transp_Xs_ = nx.dot(transp, self.xt_)
# define the transported points
transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :]
transp_Xs.append(transp_Xs_)
- transp_Xs = np.concatenate(transp_Xs, axis=0)
+ transp_Xs = nx.concatenate(transp_Xs, axis=0)
return transp_Xs
@@ -1127,26 +1159,27 @@ class BaseTransport(BaseEstimator):
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(ys=ys):
- ysTemp = label_normalization(np.copy(ys))
- classes = np.unique(ysTemp)
+ ysTemp = label_normalization(nx.copy(ys))
+ classes = nx.unique(ysTemp)
n = len(classes)
- D1 = np.zeros((n, len(ysTemp)))
+ D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_)
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True)
+ transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
for c in classes:
D1[int(c), ysTemp == c] = 1
# compute propagated labels
- transp_ys = np.dot(D1, transp)
+ transp_ys = nx.dot(D1, transp)
return transp_ys.T
@@ -1176,23 +1209,24 @@ class BaseTransport(BaseEstimator):
transp_Xt : array-like, shape (n_source_samples, n_features)
The transported target samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xt=Xt):
- if np.array_equal(self.xt_, Xt):
+ if nx.array_equal(self.xt_, Xt):
# perform standard barycentric mapping
- transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None]
+ transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
# set nans to 0
- transp_[~ np.isfinite(transp_)] = 0
+ transp_[~ nx.isfinite(transp_)] = 0
# compute transported samples
- transp_Xt = np.dot(transp_, self.xs_)
+ transp_Xt = nx.dot(transp_, self.xs_)
else:
# perform out of sample mapping
- indices = np.arange(Xt.shape[0])
+ indices = nx.arange(Xt.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]
@@ -1200,20 +1234,20 @@ class BaseTransport(BaseEstimator):
transp_Xt = []
for bi in batch_ind:
D0 = dist(Xt[bi], self.xt_)
- idx = np.argmin(D0, axis=1)
+ idx = nx.argmin(D0, axis=1)
# transport the target samples
- transp_ = self.coupling_.T / np.sum(
+ transp_ = self.coupling_.T / nx.sum(
self.coupling_, 0)[:, None]
- transp_[~ np.isfinite(transp_)] = 0
- transp_Xt_ = np.dot(transp_, self.xs_)
+ transp_[~ nx.isfinite(transp_)] = 0
+ transp_Xt_ = nx.dot(transp_, self.xs_)
# define the transported points
transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :]
transp_Xt.append(transp_Xt_)
- transp_Xt = np.concatenate(transp_Xt, axis=0)
+ transp_Xt = nx.concatenate(transp_Xt, axis=0)
return transp_Xt
@@ -1230,26 +1264,27 @@ class BaseTransport(BaseEstimator):
transp_ys : array-like, shape (n_source_samples, nb_classes)
Estimated soft source labels.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(yt=yt):
- ytTemp = label_normalization(np.copy(yt))
- classes = np.unique(ytTemp)
+ ytTemp = label_normalization(nx.copy(yt))
+ classes = nx.unique(ytTemp)
n = len(classes)
- D1 = np.zeros((n, len(ytTemp)))
+ D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_)
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
for c in classes:
D1[int(c), ytTemp == c] = 1
# compute propagated samples
- transp_ys = np.dot(D1, transp.T)
+ transp_ys = nx.dot(D1, transp.T)
return transp_ys.T
@@ -1330,14 +1365,15 @@ class LinearTransport(BaseTransport):
self : object
Returns self.
"""
+ nx = self._get_backend(Xs, ys, Xt, yt)
self.mu_s = self.distribution_estimation(Xs)
self.mu_t = self.distribution_estimation(Xt)
# coupling estimation
returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
- ws=self.mu_s.reshape((-1, 1)),
- wt=self.mu_t.reshape((-1, 1)),
+ ws=nx.reshape(self.mu_s, (-1, 1)),
+ wt=nx.reshape(self.mu_t, (-1, 1)),
bias=self.bias, log=self.log)
# deal with the value of log
@@ -1348,8 +1384,8 @@ class LinearTransport(BaseTransport):
self.log_ = dict()
# re compute inverse mapping
- self.A1_ = linalg.inv(self.A_)
- self.B1_ = -self.B_.dot(self.A1_)
+ self.A1_ = nx.inv(self.A_)
+ self.B1_ = -nx.dot(self.B_, self.A1_)
return self
@@ -1378,10 +1414,11 @@ class LinearTransport(BaseTransport):
transp_Xs : array-like, shape (n_source_samples, n_features)
The transport source samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- transp_Xs = Xs.dot(self.A_) + self.B_
+ transp_Xs = nx.dot(Xs, self.A_) + self.B_
return transp_Xs
@@ -1411,10 +1448,11 @@ class LinearTransport(BaseTransport):
transp_Xt : array-like, shape (n_source_samples, n_features)
The transported target samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xt=Xt):
- transp_Xt = Xt.dot(self.A1_) + self.B1_
+ transp_Xt = nx.dot(Xt, self.A1_) + self.B1_
return transp_Xt
@@ -2112,6 +2150,7 @@ class MappingTransport(BaseEstimator):
self : object
Returns self
"""
+ self._get_backend(Xs, ys, Xt, yt)
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt):
@@ -2158,19 +2197,20 @@ class MappingTransport(BaseEstimator):
transp_Xs : array-like, shape (n_source_samples, n_features)
The transport source samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- if np.array_equal(self.xs_, Xs):
+ if nx.array_equal(self.xs_, Xs):
# perform standard barycentric mapping
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute transported samples
- transp_Xs = np.dot(transp, self.xt_)
+ transp_Xs = nx.dot(transp, self.xt_)
else:
if self.kernel == "gaussian":
K = kernel(Xs, self.xs_, method=self.kernel,
@@ -2178,8 +2218,10 @@ class MappingTransport(BaseEstimator):
elif self.kernel == "linear":
K = Xs
if self.bias:
- K = np.hstack((K, np.ones((Xs.shape[0], 1))))
- transp_Xs = K.dot(self.mapping_)
+ K = nx.concatenate(
+ [K, nx.ones((Xs.shape[0], 1), type_as=K)], axis=1
+ )
+ transp_Xs = nx.dot(K, self.mapping_)
return transp_Xs
@@ -2396,6 +2438,7 @@ class JCPOTTransport(BaseTransport):
self : object
Returns self.
"""
+ self._get_backend(*Xs, *ys, Xt, yt)
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt, ys=ys):
@@ -2438,28 +2481,29 @@ class JCPOTTransport(BaseTransport):
batch_size : int, optional (default=128)
The batch size for out of sample inverse transform
"""
+ nx = self.nx
transp_Xs = []
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]):
+ if all([nx.allclose(x, y) for x, y in zip(self.xs_, Xs)]):
# perform standard barycentric mapping for each source domain
for coupling in self.coupling_:
- transp = coupling / np.sum(coupling, 1)[:, None]
+ transp = coupling / nx.sum(coupling, 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute transported samples
- transp_Xs.append(np.dot(transp, self.xt_))
+ transp_Xs.append(nx.dot(transp, self.xt_))
else:
# perform out of sample mapping
- indices = np.arange(Xs.shape[0])
+ indices = nx.arange(Xs.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]
@@ -2470,23 +2514,22 @@ class JCPOTTransport(BaseTransport):
transp_Xs_ = []
# get the nearest neighbor in the sources domains
- xs = np.concatenate(self.xs_, axis=0)
- idx = np.argmin(dist(Xs[bi], xs), axis=1)
+ xs = nx.concatenate(self.xs_, axis=0)
+ idx = nx.argmin(dist(Xs[bi], xs), axis=1)
# transport the source samples
for coupling in self.coupling_:
- transp = coupling / np.sum(
- coupling, 1)[:, None]
- transp[~ np.isfinite(transp)] = 0
- transp_Xs_.append(np.dot(transp, self.xt_))
+ transp = coupling / nx.sum(coupling, 1)[:, None]
+ transp[~ nx.isfinite(transp)] = 0
+ transp_Xs_.append(nx.dot(transp, self.xt_))
- transp_Xs_ = np.concatenate(transp_Xs_, axis=0)
+ transp_Xs_ = nx.concatenate(transp_Xs_, axis=0)
# define the transported points
transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :]
transp_Xs.append(transp_Xs_)
- transp_Xs = np.concatenate(transp_Xs, axis=0)
+ transp_Xs = nx.concatenate(transp_Xs, axis=0)
return transp_Xs
@@ -2512,32 +2555,36 @@ class JCPOTTransport(BaseTransport):
"Optimal transport for multi-source domain adaptation under target shift",
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(ys=ys):
- yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0]))
+ yt = nx.zeros(
+ (len(nx.unique(nx.concatenate(ys))), self.xt_.shape[0]),
+ type_as=ys[0]
+ )
for i in range(len(ys)):
- ysTemp = label_normalization(np.copy(ys[i]))
- classes = np.unique(ysTemp)
+ ysTemp = label_normalization(nx.copy(ys[i]))
+ classes = nx.unique(ysTemp)
n = len(classes)
ns = len(ysTemp)
# perform label propagation
- transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+ transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
if self.log:
D1 = self.log_['D1'][i]
else:
- D1 = np.zeros((n, ns))
+ D1 = nx.zeros((n, ns), type_as=transp)
for c in classes:
D1[int(c), ysTemp == c] = 1
# compute propagated labels
- yt = yt + np.dot(D1, transp) / len(ys)
+ yt = yt + nx.dot(D1, transp) / len(ys)
return yt.T
@@ -2555,14 +2602,15 @@ class JCPOTTransport(BaseTransport):
transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes)
A list of estimated soft source labels
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(yt=yt):
transp_ys = []
- ytTemp = label_normalization(np.copy(yt))
- classes = np.unique(ytTemp)
+ ytTemp = label_normalization(nx.copy(yt))
+ classes = nx.unique(ytTemp)
n = len(classes)
- D1 = np.zeros((n, len(ytTemp)))
+ D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0])
for c in classes:
D1[int(c), ytTemp == c] = 1
@@ -2570,12 +2618,12 @@ class JCPOTTransport(BaseTransport):
for i in range(len(self.xs_)):
# perform label propagation
- transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+ transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute propagated labels
- transp_ys.append(np.dot(D1, transp.T).T)
+ transp_ys.append(nx.dot(D1, transp.T).T)
return transp_ys
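Most of the ot/da.py changes above swap raw NumPy and scipy.linalg calls for their backend counterparts, so the closed-form affine map in OT_mapping_linear is now a chain of backend matrix products. A simplified sketch of that computation against the backend API, with uniform weights and no bias, assuming dots(A, B, C) chains the products A·B·C as it is used throughout this file (data is illustrative):

    import numpy as np
    from ot.backend import get_backend
    from ot.utils import dots

    xs = np.random.randn(30, 2)                                 # source samples
    xt = np.random.randn(30, 2) @ np.array([[2.0, 0.5], [0.0, 1.0]])  # target samples
    nx = get_backend(xs, xt)

    reg = 1e-6
    d = xs.shape[1]
    Cs = nx.dot(xs.T, xs) / xs.shape[0] + reg * nx.eye(d, type_as=xs)
    Ct = nx.dot(xt.T, xt) / xt.shape[0] + reg * nx.eye(d, type_as=xt)

    Cs12 = nx.sqrtm(Cs)                   # Cs^{1/2}
    Cs_12 = nx.inv(Cs12)                  # Cs^{-1/2}
    M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))   # (Cs^{1/2} Ct Cs^{1/2})^{1/2}
    A = dots(Cs_12, M0, Cs_12)            # linear Monge map between the two Gaussians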
diff --git a/ot/dr.py b/ot/dr.py
index 1671ca0..0955c55 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -11,6 +11,7 @@ Dimension reduction with OT
# Author: Remi Flamary <remi.flamary@unice.fr>
# Minhui Huang <mhhuang@ucdavis.edu>
+# Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
#
# License: MIT License
@@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k):
return G
+def logsumexp(M, axis):
+ r"""Log-sum-exp reduction compatible with autograd (no numpy implementation)
+ """
+ amax = np.amax(M, axis=axis, keepdims=True)
+ return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)
+
+
+def sinkhorn_log(w1, w2, M, reg, k):
+ r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd)
+ """
+ Mr = -M / reg
+ ui = np.zeros((M.shape[0],))
+ vi = np.zeros((M.shape[1],))
+ log_w1 = np.log(w1)
+ log_w2 = np.log(w2)
+ for i in range(k):
+ vi = log_w2 - logsumexp(Mr + ui[:, None], 0)
+ ui = log_w1 - logsumexp(Mr + vi[None, :], 1)
+ G = np.exp(ui[:, None] + Mr + vi[None, :])
+ return G
+
+
def split_classes(X, y):
r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
"""
@@ -110,7 +133,7 @@ def fda(X, y, p=2, reg=1e-16):
return Popt, proj
-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False):
r"""
Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
@@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
- :math:`W` is entropic regularized Wasserstein distances
- :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sparse cost matrices, you should use the
+ :py:func:`ot.dr.sinkhorn_log` solver that will avoid numerical
+ errors, but can be slow in practice.
+
Parameters
----------
X : ndarray, shape (n, d)
@@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
solver : None | str, optional
None for steepest descent or 'TrustRegions' for trust regions algorithm
else should be a pymanopt.solvers
+ sinkhorn_method : str
+ method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log'
P0 : ndarray, shape (d, p)
Initial starting point for projection.
normalize : bool, optional
@@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
""" # noqa
+ if sinkhorn_method.lower() == 'sinkhorn':
+ sinkhorn_solver = sinkhorn
+ elif sinkhorn_method.lower() == 'sinkhorn_log':
+ sinkhorn_solver = sinkhorn_log
+ else:
+ raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method)
+
mx = np.mean(X)
X -= mx.reshape((1, -1))
@@ -193,7 +233,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
for j, xj in enumerate(xc[i:]):
xj = np.dot(xj, P)
M = dist(xi, xj)
- G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
+ G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k)
if j == 0:
loss_w += np.sum(G * M)
else:
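ot/dr.py gains a log-domain Sinkhorn (sinkhorn_log, built on the autograd-friendly logsumexp above), and wda now takes a sinkhorn_method argument to select it. A sketch of requesting the log-domain solver for a small regularization, assuming the optional pymanopt/autograd dependencies are installed and that wda keeps returning the projection matrix and projector as before (the data is made up):

    import numpy as np
    from ot.dr import wda

    # two classes of illustrative 10-D samples
    X = np.vstack([np.random.randn(40, 10), np.random.randn(40, 10) + 2.0])
    y = np.repeat([0, 1], 40)

    # a small reg can underflow the plain solver; 'sinkhorn_log' stays finite
    Popt, proj = wda(X, y, p=2, reg=1e-3, k=10, sinkhorn_method='sinkhorn_log')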
diff --git a/ot/factored.py b/ot/factored.py
new file mode 100644
index 0000000..abc2445
--- /dev/null
+++ b/ot/factored.py
@@ -0,0 +1,145 @@
+"""
+Factored OT solvers (low rank, cost or OT plan)
+"""
+
+# Author: Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .utils import dist
+from .lp import emd
+from .bregman import sinkhorn
+
+__all__ = ['factored_optimal_transport']
+
+
+def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
+ r"""Solves factored OT problem and return OT plans and intermediate distribution
+
+ This function solve the following OT problem [40]_
+
+ .. math::
+ \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)
+
+ where :
+
+ - :math:`\mu_a` and :math:`\mu_b` are empirical distributions.
+ - :math:`\mu` is an empirical distribution with r samples
+
+ And returns the two OT plans between
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ Uses the conditional gradient algorithm to solve the problem proposed in
+ :ref:`[39] <references-weak>`.
+
+ Parameters
+ ----------
+ Xa : (ns,d) array-like, float
+ Source samples
+ Xb : (nt,d) array-like, float
+ Target samples
+ a : (ns,) array-like, float
+ Source histogram (uniform weight if empty list)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list))
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ Ga: array-like, shape (ns, r)
+ Optimal transportation matrix between source and the intermediate
+ distribution
+ Gb: array-like, shape (r, nt)
+ Optimal transportation matrix between the intermediate and target
+ distribution
+ X: array-like, shape (r, d)
+ Support of the intermediate distribution
+ log: dict, optional
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
+
+
+ .. _references-factored:
+ References
+ ----------
+ .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
+ G., & Weed, J. (2019, April). Statistical optimal transport via factored
+ couplings. In The 22nd International Conference on Artificial
+ Intelligence and Statistics (pp. 2454-2465). PMLR.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General
+ regularized OT
+ """
+
+ nx = get_backend(Xa, Xb)
+
+ n_a = Xa.shape[0]
+ n_b = Xb.shape[0]
+ d = Xa.shape[1]
+
+ if a is None:
+ a = nx.ones((n_a), type_as=Xa) / n_a
+ if b is None:
+ b = nx.ones((n_b), type_as=Xb) / n_b
+
+ if X0 is None:
+ X = nx.randn(r, d, type_as=Xa)
+ else:
+ X = X0
+
+ w = nx.ones(r, type_as=Xa) / r
+
+ def solve_ot(X1, X2, w1, w2):
+ M = dist(X1, X2)
+ if reg > 0:
+ G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
+ log['cost'] = nx.sum(G * M)
+ return G, log
+ else:
+ return emd(w1, w2, M, log=True, **kwargs)
+
+ norm_delta = []
+
+ # solve the barycenter
+ for i in range(numItermax):
+
+ old_X = X
+
+ # solve OT with template
+ Ga, loga = solve_ot(Xa, X, a, w)
+ Gb, logb = solve_ot(X, Xb, w, b)
+
+ X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r
+
+ delta = nx.norm(X - old_X)
+ if delta < stopThr:
+ break
+ if log:
+ norm_delta.append(delta)
+
+ if log:
+ log_dic = {'delta_iter': norm_delta,
+ 'ua': loga['u'],
+ 'va': loga['v'],
+ 'ub': logb['u'],
+ 'vb': logb['v'],
+ 'costa': loga['cost'],
+ 'costb': logb['cost'],
+ }
+ return Ga, Gb, X, log_dic
+
+ return Ga, Gb, X
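The new solver alternates two OT problems (exact emd, or Sinkhorn when reg > 0) against an intermediate r-point distribution and updates that support as the average of the two barycentric projections. A minimal sketch of calling it with made-up data:

    import numpy as np
    from ot.factored import factored_optimal_transport

    Xa = np.random.randn(100, 2)          # illustrative source samples
    Xb = np.random.randn(80, 2) + 1.0     # illustrative target samples

    # Ga: (100, r) plan onto the intermediate support X of shape (r, 2)
    # Gb: (r, 80) plan from the intermediate support onto the target
    Ga, Gb, X = factored_optimal_transport(Xa, Xb, r=10)

    # with reg > 0 the two inner problems are solved with Sinkhorn, and log=True
    # additionally returns the dual variables, costs and iteration deltas
    Ga, Gb, X, log = factored_optimal_transport(Xa, Xb, reg=0.1, r=10, log=True)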
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
deleted file mode 100644
index 12db605..0000000
--- a/ot/gpu/__init__.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-GPU implementation for several OT solvers and utility
-functions.
-
-The GPU backend in handled by `cupy
-<https://cupy.chainer.org/>`_.
-
-.. warning::
- This module is now deprecated and will be removed in future releases. POT
- now privides a backend mechanism that allows for solving prolem on GPU wth
- the pytorch backend.
-
-
-.. warning::
- Note that by default the module is not imported in :mod:`ot`. In order to
- use it you need to explicitely import :mod:`ot.gpu` .
-
-By default, the functions in this module accept and return numpy arrays
-in order to proide drop-in replacement for the other POT function but
-the transfer between CPU en GPU comes with a significant overhead.
-
-In order to get the best performances, we recommend to give only cupy
-arrays to the functions and desactivate the conversion to numpy of the
-result of the function with parameter ``to_numpy=False``.
-
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Leo Gautheron <https://github.com/aje>
-#
-# License: MIT License
-
-import warnings
-
-from . import bregman
-from . import da
-from .bregman import sinkhorn
-from .da import sinkhorn_lpl1_mm
-
-from . import utils
-from .utils import dist, to_gpu, to_np
-
-
-warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning)
-
-
-__all__ = ["utils", "dist", "sinkhorn",
- "sinkhorn_lpl1_mm", 'bregman', 'da', 'to_gpu', 'to_np']
-
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
deleted file mode 100644
index 76af00e..0000000
--- a/ot/gpu/bregman.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Bregman projections for regularized OT with GPU
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Leo Gautheron <https://github.com/aje>
-#
-# License: MIT License
-
-import cupy as np # np used for matrix computation
-import cupy as cp # cp used for cupy specific operations
-from . import utils
-
-
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
- verbose=False, log=False, to_numpy=True, **kwargs):
- r"""
- Solve the entropic regularization optimal transport on GPU
-
- If the input matrix are in numpy format, they will be uploaded to the
- GPU first which can incur significant time overhead.
-
- The function solves the following optimization problem:
-
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
-
- s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
- \gamma\geq 0
- where :
-
- - M is the (ns,nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
-
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
-
-
- Parameters
- ----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
- loss matrix
- reg : float
- Regularization term >0
- numItermax : int, optional
- Max number of iterations
- stopThr : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
- to_numpy : boolean, optional (default True)
- If true convert back the GPU array result to numpy format.
-
-
- Returns
- -------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
-
- References
- ----------
-
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
-
-
- See Also
- --------
- ot.lp.emd : Unregularized OT
- ot.optim.cg : General regularized OT
-
- """
-
- a = cp.asarray(a)
- b = cp.asarray(b)
- M = cp.asarray(M)
-
- if len(a) == 0:
- a = np.ones((M.shape[0],)) / M.shape[0]
- if len(b) == 0:
- b = np.ones((M.shape[1],)) / M.shape[1]
-
- # init data
- Nini = len(a)
- Nfin = len(b)
-
- if len(b.shape) > 1:
- nbb = b.shape[1]
- else:
- nbb = 0
-
- if log:
- log = {'err': []}
-
- # we assume that no distances are null except those of the diagonal of
- # distances
- if nbb:
- u = np.ones((Nini, nbb)) / Nini
- v = np.ones((Nfin, nbb)) / Nfin
- else:
- u = np.ones(Nini) / Nini
- v = np.ones(Nfin) / Nfin
-
- # print(reg)
-
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
-
- # print(np.min(K))
- tmp2 = np.empty(b.shape, dtype=M.dtype)
-
- Kp = (1 / a).reshape(-1, 1) * K
- cpt = 0
- err = 1
- while (err > stopThr and cpt < numItermax):
- uprev = u
- vprev = v
-
- KtransposeU = np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.dot(Kp, v)
-
- if (np.any(KtransposeU == 0) or
- np.any(np.isnan(u)) or np.any(np.isnan(v)) or
- np.any(np.isinf(u)) or np.any(np.isinf(v))):
- # we have reached the machine precision
- # come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
- u = uprev
- v = vprev
- break
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only all
- # the 10th iterations
- if nbb:
- err = np.sqrt(
- np.sum((u - uprev)**2) / np.sum((u)**2)
- + np.sum((v - vprev)**2) / np.sum((v)**2)
- )
- else:
- # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
- tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
- #tmp2=np.einsum('i,ij,j->j', u, K, v)
- err = np.linalg.norm(tmp2 - b) # violation of marginal
- if log:
- log['err'].append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print(
- '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
- if log:
- log['u'] = u
- log['v'] = v
-
- if nbb: # return only loss
- #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory)
- res = np.empty(nbb)
- for i in range(nbb):
- res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i])
- if to_numpy:
- res = utils.to_np(res)
- if log:
- return res, log
- else:
- return res
-
- else: # return OT matrix
- res = u.reshape((-1, 1)) * K * v.reshape((1, -1))
- if to_numpy:
- res = utils.to_np(res)
- if log:
- return res, log
- else:
- return res
-
-
-# define sinkhorn as sinkhorn_knopp
-sinkhorn = sinkhorn_knopp
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
deleted file mode 100644
index 7adb830..0000000
--- a/ot/gpu/da.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Domain adaptation with optimal transport with GPU implementation
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Nicolas Courty <ncourty@irisa.fr>
-# Michael Perrot <michael.perrot@univ-st-etienne.fr>
-# Leo Gautheron <https://github.com/aje>
-#
-# License: MIT License
-
-
-import cupy as np # np used for matrix computation
-import cupy as cp # cp used for cupy specific operations
-import numpy as npp
-from . import utils
-
-from .bregman import sinkhorn
-
-
-def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
- numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
- log=False, to_numpy=True):
- """
- Solve the entropic regularization optimal transport problem with nonconvex
- group lasso regularization on GPU
-
- If the input matrices are in numpy format, they will be uploaded to the
- GPU first which can incur significant time overhead.
-
-
- The function solves the following optimization problem:
-
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
- + \eta \Omega_g(\gamma)
-
- s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
- \gamma\geq 0
- where :
-
- - M is the (ns,nt) metric cost matrix
- - :math:`\Omega_e` is the entropic regularization term
- :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\Omega_g` is the group lasso regularization term
- :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
- where :math:`\mathcal{I}_c` are the index of samples from class c
- in the source domain.
- - a and b are source and target weights (sum to 1)
-
- The algorithm used for solving the problem is the generalised conditional
- gradient as proposed in [5]_ [7]_
-
-
- Parameters
- ----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- labels_a : np.ndarray (ns,)
- labels of samples in the source domain
- b : np.ndarray (nt,)
- samples weights in the target domain
- M : np.ndarray (ns,nt)
- loss matrix
- reg : float
- Regularization term for entropic regularization >0
- eta : float, optional
- Regularization term for group lasso regularization >0
- numItermax : int, optional
- Max number of iterations
- numInnerItermax : int, optional
- Max number of iterations (inner sinkhorn solver)
- stopInnerThr : float, optional
- Stop threshold on error (inner sinkhorn solver) (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
- to_numpy : boolean, optional (default True)
- If true convert back the GPU array result to numpy format.
-
-
- Returns
- -------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
-
- References
- ----------
-
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
- "Optimal Transport for Domain Adaptation," in IEEE
- Transactions on Pattern Analysis and Machine Intelligence ,
- vol.PP, no.99, pp.1-1
- .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
- Generalized conditional gradient: analysis of convergence
- and applications. arXiv preprint arXiv:1510.06567.
-
- See Also
- --------
- ot.lp.emd : Unregularized OT
- ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT
-
- """
-
- a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M)
-
- p = 0.5
- epsilon = 1e-3
-
- indices_labels = []
- labels_a2 = cp.asnumpy(labels_a)
- classes = npp.unique(labels_a2)
- for c in classes:
- idxc = utils.to_gpu(*npp.where(labels_a2 == c))
- indices_labels.append(idxc)
-
- W = np.zeros(M.shape)
-
- for cpt in range(numItermax):
- Mreg = M + eta * W
- transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
- stopThr=stopInnerThr, to_numpy=False)
- # the transport has been computed. Check if classes are really
- # separated
- W = np.ones(M.shape)
- for (i, c) in enumerate(classes):
-
- majs = np.sum(transp[indices_labels[i]], axis=0)
- majs = p * ((majs + epsilon)**(p - 1))
- W[indices_labels[i]] = majs
-
- if to_numpy:
- return utils.to_np(transp)
- else:
- return transp
diff --git a/ot/gpu/utils.py b/ot/gpu/utils.py
deleted file mode 100644
index 41e168a..0000000
--- a/ot/gpu/utils.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Utility functions for GPU
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Nicolas Courty <ncourty@irisa.fr>
-# Leo Gautheron <https://github.com/aje>
-#
-# License: MIT License
-
-import cupy as np # np used for matrix computation
-import cupy as cp # cp used for cupy specific operations
-
-
-def euclidean_distances(a, b, squared=False, to_numpy=True):
- """
- Compute the pairwise euclidean distance between matrices a and b.
-
- If the input matrices are in numpy format, they will be uploaded to the
- GPU first which can incur significant time overhead.
-
- Parameters
- ----------
- a : np.ndarray (n, f)
- first matrix
- b : np.ndarray (m, f)
- second matrix
- to_numpy : boolean, optional (default True)
- If true convert back the GPU array result to numpy format.
- squared : boolean, optional (default False)
- if True, return squared euclidean distance matrix
-
- Returns
- -------
- c : (n x m) np.ndarray or cupy.ndarray
- pairwise euclidean distance matrix
- """
-
- a, b = to_gpu(a, b)
-
- a2 = np.sum(np.square(a), 1)
- b2 = np.sum(np.square(b), 1)
-
- c = -2 * np.dot(a, b.T)
- c += a2[:, None]
- c += b2[None, :]
-
- if not squared:
- np.sqrt(c, out=c)
- if to_numpy:
- return to_np(c)
- else:
- return c
-
-
-def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True):
- """Compute distance between samples in x1 and x2 on gpu
-
- Parameters
- ----------
-
- x1 : np.array (n1,d)
- matrix with n1 samples of size d
- x2 : np.array (n2,d), optional
- matrix with n2 samples of size d (if None then x2=x1)
- metric : str
- Metric from 'sqeuclidean', 'euclidean',
-
-
- Returns
- -------
-
- M : np.array (n1,n2)
- distance matrix computed with given metric
-
- """
- if x2 is None:
- x2 = x1
- if metric == "sqeuclidean":
- return euclidean_distances(x1, x2, squared=True, to_numpy=to_numpy)
- elif metric == "euclidean":
- return euclidean_distances(x1, x2, squared=False, to_numpy=to_numpy)
- else:
- raise NotImplementedError
-
-
-def to_gpu(*args):
- """ Upload numpy arrays to GPU and return them"""
- if len(args) > 1:
- return (cp.asarray(x) for x in args)
- else:
- return cp.asarray(args[0])
-
-
-def to_np(*args):
- """ convert GPU arras to numpy and return them"""
- if len(args) > 1:
- return (cp.asnumpy(x) for x in args)
- else:
- return cp.asnumpy(args[0])
diff --git a/ot/gromov.py b/ot/gromov.py
index 6544260..55ab0bd 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -7,6 +7,7 @@ Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers
# Nicolas Courty <ncourty@irisa.fr>
# Rémi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License
@@ -17,7 +18,7 @@ from .bregman import sinkhorn
from .utils import dist, UndefinedParameter, list_to_array
from .optim import cg
from .lp import emd_1d, emd
-from .utils import check_random_state
+from .utils import check_random_state, unif
from .backend import get_backend
@@ -320,7 +321,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return nx.exp(tmpsum / ppt)
-def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
r"""
Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -338,6 +339,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
- :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -361,6 +366,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as the initial transport plan of the solver.
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
@@ -385,18 +393,26 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20 = p, q, C1, C2
- nx = get_backend(p0, q0, C10, C20)
-
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
- G0 = p[:, None] * q[None, :]
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -414,7 +430,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
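A minimal sketch of the new G0 parameter: any coupling whose marginals match p and q (as checked with np.testing.assert_allclose above) can be used to warm-start the conditional gradient solver. The sizes below are illustrative:

    import numpy as np
    import ot

    n = 20
    p, q = ot.unif(n), ot.unif(n)
    C1 = np.random.rand(n, n); C1 = 0.5 * (C1 + C1.T)
    C2 = np.random.rand(n, n); C2 = 0.5 * (C2 + C2.T)

    G0 = np.eye(n) / n  # identity coupling: rows sum to p, columns sum to q
    T = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0)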
-def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
r"""
Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -436,6 +452,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
Note that when using backends, this loss function is differentiable wrt the
matrices and weights for quadratic loss using the gradients from [38]_.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -459,6 +479,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as the initial transport plan of the solver.
Returns
-------
@@ -483,9 +506,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20 = p, q, C1, C2
- nx = get_backend(p0, q0, C10, C20)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
@@ -494,7 +520,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -514,10 +546,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
gw = log_gw['gw_dist']
if loss_fun == 'square_loss':
- gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
- gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+ gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+ gC1 = nx.from_numpy(gC1, type_as=C10)
+ gC2 = nx.from_numpy(gC2, type_as=C10)
gw = nx.set_gradients(gw, (p0, q0, C10, C20),
- (log_gw['u'], log_gw['v'], gC1, gC2))
+ (log_gw['u'] - nx.mean(log_gw['u']),
+ log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
if log:
return gw, log_gw
@@ -525,7 +560,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
return gw
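Because the duals are centered before being passed to set_gradients, the returned value can be differentiated with respect to the weights and both cost matrices when a differentiable backend is used. A minimal sketch with torch and square_loss (shapes are illustrative):

    import torch
    import ot

    n = 10
    p = torch.full((n,), 1.0 / n, requires_grad=True)
    q = torch.full((n,), 1.0 / n, requires_grad=True)
    C1 = torch.rand(n, n); C1 = 0.5 * (C1 + C1.T); C1.requires_grad_(True)
    C2 = torch.rand(n, n); C2 = 0.5 * (C2 + C2.T); C2.requires_grad_(True)

    gw = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss')
    gw.backward()  # populates p.grad, q.grad, C1.grad and C2.grad
    print(C1.grad.shape, p.grad.shape)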
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
r"""
Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
@@ -545,6 +580,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
@@ -566,6 +605,9 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as the initial transport plan of the solver.
log : bool, optional
record log if True
**kwargs : dict
@@ -588,20 +630,28 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
(ICML). 2019.
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20, M0 = p, q, C1, C2, M
- nx = get_backend(p0, q0, C10, C20, M0)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
M = nx.to_numpy(M0)
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
-
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -610,19 +660,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
-
fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
-
log['fgw_dist'] = fgw_dist
log['u'] = nx.from_numpy(log['u'], type_as=C10)
log['v'] = nx.from_numpy(log['v'], type_as=C10)
return nx.from_numpy(res, type_as=C10), log
-
else:
return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
-def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
r"""
Computes the FGW distance between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
@@ -645,6 +692,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
The algorithm used for solving the problem is conditional gradient as
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Note that when using backends, this loss function is differentiable wrt the
matrices and weights for quadratic loss using the gradients from [38]_.
@@ -667,6 +718,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
armijo : bool, optional
If True the step of the line-search is found via an armijo research.
Else closed form is used. If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as the initial transport plan of the solver.
log : bool, optional
Record log if True.
**kwargs : dict
@@ -695,7 +749,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
p, q = list_to_array(p, q)
p0, q0, C10, C20, M0 = p, q, C1, C2, M
- nx = get_backend(p0, q0, C10, C20, M0)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
@@ -705,7 +763,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -725,10 +789,14 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
log_fgw['T'] = T0
if loss_fun == 'square_loss':
- gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
- gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+ gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+ gC1 = nx.from_numpy(gC1, type_as=C10)
+ gC2 = nx.from_numpy(gC2, type_as=C10)
fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
- (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T0))
if log:
return fgw_dist, log_fgw
@@ -1780,3 +1848,988 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
for s in range(len(Ts))
])
return tmpsum
+
+
+def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2
+
+ such that, :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - reg is the regularization coefficient.
+
+ The stochastic algorithm used for estimating the graph dictionary atoms is the one proposed in [38]
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns, ns).
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+ Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+ Number of epochs used to learn the dictionary. Default is 20.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate: float, optional
+ Learning rate used for the stochastic gradient descent. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary.
+ If set to None (Default), the dictionary will be initialized randomly.
+ Else Cdict must have shape (D, nt, nt), i.e. match the provided number of atoms D and atom size nt.
+ projection: str, optional
+ If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
+ Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
+ use_log: bool, optional
+ If set to True, losses evolution by batches and epochs are tracked. Default is False.
+ use_adam_optimizer: bool, optional
+ If set to True, adam optimizer with default settings is used as an adaptive learning rate strategy.
+ Else perform SGD with fixed learning rate. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+ If use_log is True, contains loss evolutions by batches and epochs.
+ References
+ -------
+
+ ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ # Handle backend of non-optional arguments
+ Cs0 = Cs
+ nx = get_backend(*Cs0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ dataset_size = len(Cs)
+ # Handle backend of optional arguments
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0
+ if use_adam_optimizer:
+ adam_moments = _initialize_adam_optimizer(Cdict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ Cdict_best_state = Cdict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+ # batch sampling
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+ # BCD solver for Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner,
+ max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ if use_adam_optimizer:
+ Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments)
+ else:
+ Cdict -= learning_rate * grad_Cdict
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ if verbose:
+ print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), log
+
+
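A minimal usage sketch of the dictionary learning routine above, on a toy dataset of random symmetric matrices (all sizes and hyperparameters are illustrative, not recommendations):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    # toy dataset: 30 random symmetric "graphs" of varying size
    Cs = []
    for _ in range(30):
        n = rng.randint(10, 20)
        C = rng.rand(n, n)
        Cs.append(0.5 * (C + C.T))

    # learn D=3 atoms with nt=12 nodes each
    Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
        Cs, D=3, nt=12, reg=0.01, epochs=5, batch_size=8,
        learning_rate=0.1, use_log=True)
    print(Cdict.shape)  # (3, 12, 12)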
+def _initialize_adam_optimizer(variable):
+
+ # Initialization for our numpy implementation of adam optimizer
+ atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor
+ atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor
+ atoms_adam_count = 1
+
+ return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count}
+
+
+def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09):
+
+ adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad
+ adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2)
+ unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count'])
+ unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count'])
+ variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps)
+ adam_moments['count'] += 1
+
+ return variable, adam_moments
+
+
+def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
+ r"""
+ Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`.
+
+ .. math::
+ \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ T: array-like (ns, nt)
+ Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})`
+ current_loss: float
+ reconstruction error
+ References
+ -------
+
+ ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ C0, Cdict0 = C, Cdict
+ nx = get_backend(C0, Cdict0)
+ C = nx.to_numpy(C0)
+ Cdict = nx.to_numpy(Cdict0)
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+
+ w = unif(D) # Initialize uniformly the unmixing w
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+
+ const_q = q[:, None] * q[None, :]
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+ # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
+ T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs)
+ current_loss = log['gw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing(
+ C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T,
+ starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs
+ )
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
+
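A sketch of embedding a single structure onto a dictionary with the unmixing routine above; the dictionary here is simply random, for illustration (in practice it would come from the learning function above):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    C = rng.rand(15, 15); C = 0.5 * (C + C.T)           # structure to embed
    Cdict = rng.rand(3, 12, 12)
    Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))  # 3 symmetric atoms

    w, Cembedded, T, loss = ot.gromov.gromov_wasserstein_linear_unmixing(
        C, Cdict, reg=0.01, tol_outer=1e-5, max_iter_outer=10)
    print(w.sum())  # w stays on the simplex, so this is ~1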
+def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs):
+ r"""
+ Returns for a fixed admissible transport plan,
+ the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ w: array-like, shape (D,)
+ Linear unmixing of the input structure onto the dictionary
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between the input structure and its representation in the dictionary.
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+
+ previous_loss = current_loss
+ # 1) Compute gradient at current point w
+ grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ grad_w -= 2 * reg * w
+
+ # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0: # note that the loss can be negative if reg > 0
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ count += 1
+
+ return w, Cembedded, current_loss
+
+
+def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing
+ .. math::
+ \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+ grad_w : array-like, shape (D,)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ Cembedded_diff: numpy array, shape (nt, nt)
+ Difference between the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`.
+ """
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ a = trace_diffx - trace_diffw
+ b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+
+ if a > 0:
+ gamma = min(1, max(0, - b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff
+
+
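The branching at the end of the line search above simply minimizes the quadratic a*gamma**2 + b*gamma over [0, 1]. A small standalone check of that logic (independent of the unmixing code):

    import numpy as np

    def quad_min_on_unit_interval(a, b):
        # minimizer of a * g**2 + b * g over g in [0, 1]
        if a > 0:
            return min(1, max(0, -b / (2 * a)))
        elif a + b < 0:  # concave/linear case: value at g=1 is below the value at g=0
            return 1
        else:
            return 0

    grid = np.linspace(0, 1, 1001)
    for a, b in [(2.0, -1.0), (-1.0, -1.0), (-1.0, 2.0)]:
        g = quad_min_on_unit_interval(a, b)
        g_ref = grid[np.argmin(a * grid**2 + b * grid)]
        print(g, g_ref)  # the closed form matches the grid search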
+def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
+ Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2
+
+
+ Such that :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) feature matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+
+ The stochastic algorithm used for estimating the attributed graph dictionary atoms is the one proposed in [38]
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns,ns).
+ Ys : list of S array-like, shape (ns, d)
+ List of feature matrix of variable size (ns,d) with d fixed.
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ alpha : float
+ Trade-off parameter of Fused Gromov-Wasserstein
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+ Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+ Number of epochs used to learn the dictionary. Default is 20.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate_C: float, optional
+ Learning rate used for the stochastic gradient descent on Cdict. Default is 1.
+ learning_rate_Y: float, optional
+ Learning rate used for the stochastic gradient descent on Ydict. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary structures Cdict.
+ If set to None (Default), the dictionary will be initialized randomly.
+ Else Cdict must have shape (D, nt, nt), i.e. match the provided number of atoms D and atom size nt.
+ Ydict_init: list of D array-like with shape (nt, d), optional
+ Used to initialize the dictionary features Ydict.
+ If set to None, the dictionary features will be initialized randomly.
+ Else Ydict must have shape (D, nt, d) where d is the feature dimension of the inputs Ys.
+ projection: str, optional
+ If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
+ Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
+ use_log: bool, optional
+ If set to True, losses evolution by batches and epochs are tracked. Default is False.
+ use_adam_optimizer: bool, optional
+ If set to True, adam optimizer with default settings is used as an adaptive learning rate strategy.
+ Else perform SGD with fixed learning rate. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ Ydict_best_state : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+ If use_log is True, contains loss evolutions by batches and epochs.
+ References
+ -------
+
+ ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ Cs0, Ys0 = Cs, Ys
+ nx = get_backend(*Cs0, *Ys0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ Ys = [nx.to_numpy(Y) for Y in Ys0]
+
+ d = Ys[0].shape[-1]
+ dataset_size = len(Cs)
+
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+ if Ydict_init is None:
+ # Initialize randomly features of dictionary atoms based on samples distribution by feature component
+ dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
+ Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+ else:
+ Ydict = nx.to_numpy(Ydict_init).copy()
+ assert Ydict.shape == (D, nt, d)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_adam_optimizer:
+ adam_moments_C = _initialize_adam_optimizer(Cdict)
+ adam_moments_Y = _initialize_adam_optimizer(Ydict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+
+ # Batch iterations
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ys_embedded = np.zeros((batch_size, nt, d))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+ # BCD solver for Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q,
+ tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ grad_Ydict = np.zeros_like(Ydict)
+
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx])
+ grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ grad_Ydict *= 2 / batch_size
+
+ if use_adam_optimizer:
+ Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C)
+ Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y)
+ else:
+ Cdict -= learning_rate_C * grad_Cdict
+ Ydict -= learning_rate_Y * grad_Ydict
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ if verbose:
+ print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log
+
+
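A sketch of the attributed (FGW) variant above, on toy graphs with 2-dimensional node features (all sizes and hyperparameters are illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Cs, Ys = [], []
    for _ in range(20):
        n = rng.randint(10, 20)
        C = rng.rand(n, n)
        Cs.append(0.5 * (C + C.T))  # structure
        Ys.append(rng.rand(n, 2))   # node features, d = 2

    Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
        Cs, Ys, D=3, nt=12, alpha=0.5, reg=0.01, epochs=5, batch_size=8,
        learning_rate_C=0.1, learning_rate_Y=0.1, use_log=True)
    print(Cdict.shape, Ydict.shape)  # (3, 12, 12) (3, 12, 2)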
+def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
+ r"""
+ Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`
+
+ .. math::
+ \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) feature matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Ydict : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ Yembedded: array-like, shape (nt,d)
+ embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
+ T: array-like (ns,nt)
+ Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`.
+ current_loss: float
+ reconstruction error
+ References
+ -------
+
+ ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict
+ nx = get_backend(C0, Y0, Cdict0, Ydict0)
+ C = nx.to_numpy(C0)
+ Y = nx.to_numpy(Y0)
+ Cdict = nx.to_numpy(Cdict0)
+ Ydict = nx.to_numpy(Ydict0)
+
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+ d = Y.shape[-1]
+ w = unif(D) # Initialize with uniform weights
+ ns = C.shape[-1]
+ nt = Cdict.shape[-1]
+
+ # modeling (C,Y)
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+ Yembedded = np.sum(w[:, None, None] * Ydict, axis=0)
+
+ # constants depending on q
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+ Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+
+ # 1. Solve FGW transport between (C,Y,p) and the model (\sum_d w_d Cdict[d], \sum_d w_d Ydict[d], q) fixing the unmixing w
+ Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T)
+ M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features
+ T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True)
+ current_loss = log['fgw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w,
+ T, p, q, const_q, diag_q, current_loss, alpha, reg,
+ tol=tol_inner, max_iter=max_iter_inner, **kwargs)
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
+
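For context, a hedged usage sketch of the public wrapper that the docstring above belongs to (the entry point is assumed here to be ot.gromov.fused_gromov_wasserstein_linear_unmixing, matching the parameters and return values listed above; the data are random and purely illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    ns, nt, d, D = 10, 6, 2, 3
    C = ot.dist(rng.randn(ns, d))                      # structure matrix of the input graph
    Y = rng.randn(ns, d)                               # node features of the input graph
    Cdict = rng.rand(D, nt, nt)
    Cdict = 0.5 * (Cdict + Cdict.transpose(0, 2, 1))   # symmetric dictionary structures
    Ydict = rng.randn(D, nt, d)                        # dictionary features

    w, Cemb, Yemb, T, loss = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
        C, Y, Cdict, Ydict, alpha=0.5, reg=0.)
    print(w.sum())  # the unmixing weights stay on the simplex, so this sums to ~1
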
+def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs):
+ r"""
+ Returns, for a fixed admissible transport plan,
+ the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]}, \sum_d w_d \mathbf{Y_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2
+
+ Such that :
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`(\mathbf{C}, \mathbf{Y})`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}`
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7.
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt,d)
+ Embedded features of (C,Y) onto the dictionary
+ w: array-like, shape (D,)
+ Linear unmixing of (C,Y) onto (Cdict,Ydict)
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution.
+ diag_q: array-like, shape (nt,nt)
+ diagonal matrix with values of q on the diagonal.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between (C,Y) and its model
+ p : array-like, shape (ns,)
+ Distribution in the source space (C,Y).
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing.
+ Cembedded: ndarray (nt, nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the dictionary for the updated unmixing.
+ Yembedded: ndarray (nt, d)
+ embedded features of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the dictionary for the updated unmixing.
+ current_loss: float
+ reconstruction error for the updated unmixing.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+ ones_ns_d = np.ones(Y.shape)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+ previous_loss = current_loss
+
+ # 1) Compute gradient at current point w
+ # structure
+ grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ # feature
+ grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2))
+ grad_w -= reg * w
+ grad_w *= 2
+
+ # 2) Conditional gradient direction finding: x = argmin_{x in simplex} x^T grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ Yembedded += gamma * Yembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ count += 1
+
+ return w, Cembedded, Yembedded, current_loss
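Step 2) of the loop above is the generic conditional-gradient direction search over the probability simplex: minimizing the linearized objective x^T grad_w over the simplex selects the vertex (or the average of tied vertices) with the smallest gradient entry. A self-contained numpy sketch of that step, with illustrative names only:

    import numpy as np

    def simplex_cg_direction(grad_w):
        # A linear form over the simplex is minimized at the vertices with the
        # smallest gradient value; ties are averaged so x remains a probability vector.
        min_ = np.min(grad_w)
        x = (grad_w == min_).astype(np.float64)
        return x / np.sum(x)

    print(simplex_cg_direction(np.array([0.3, -1.2, -1.2, 0.5])))  # [0.  0.5 0.5 0. ]
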
+
+
+def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs):
+ r"""
+ Compute the optimal step for the line-search problem of the Fused Gromov-Wasserstein linear unmixing
+
+ .. math::
+ \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij} - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that :
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+ grad_w : array-like, shape (D,)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Y: array-like, shape (ns, d)
+ Feature matrix of the input space
+ Cdict : list of D array-like, shape (nt, nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt, d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt, nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt, d)
+ Embedded features of (C,Y) onto the dictionary
+ T: array-like, shape (ns, nt)
+ Fixed transport plan between (C,Y) and its current model.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ ones_ns_d: array-like, shape (ns, d)
+ :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ Cembedded_diff: numpy array, shape (nt, nt)
+ Difference between the structure matrices of the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ Yembedded_diff: numpy array, shape (nt, d)
+ Difference between the feature matrices of the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ """
+ # polynomial coefficients from quadratic objective (with respect to w) on structures
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ # Constant factors appearing in the factorization a*gamma^2 + b*gamma + c of the Gromov-Wasserstein part of the reconstruction loss
+ a_gw = trace_diffx - trace_diffw
+ b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+
+ # polynomial coefficient from quadratic objective (with respect to w) on features
+ Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0)
+ Yembedded_diff = Yembedded_x - Yembedded
+ # Constant factors appearing in the factorization a*gamma^2 + b*gamma + c of the feature (Wasserstein) part of the reconstruction loss
+ a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T)
+ b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T)))
+
+ a = alpha * a_gw + (1 - alpha) * a_w
+ b = alpha * b_gw + (1 - alpha) * b_w
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+ if a > 0:
+ gamma = min(1, max(0, -b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff, Yembedded_diff
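The final branch above is the exact line search for a one-dimensional quadratic a*gamma^2 + b*gamma on [0, 1]: when the curvature a is positive the unconstrained minimizer -b/(2a) is clipped to the interval, otherwise the best endpoint is taken (gamma = 1 exactly when a + b < 0, since the objective equals 0 at gamma = 0 and a + b at gamma = 1). A small self-contained check of that rule (illustrative only):

    import numpy as np

    def quadratic_linesearch(a, b):
        # Exact minimizer of a * g**2 + b * g over g in [0, 1].
        if a > 0:  # convex parabola: clip the vertex to the interval
            return min(1., max(0., -b / (2. * a)))
        return 1. if a + b < 0 else 0.  # otherwise an endpoint is optimal

    for a, b in [(2., -3.), (-1., 0.5), (0., 2.)]:
        g_star = quadratic_linesearch(a, b)
        grid = np.linspace(0., 1., 1001)
        assert abs(a * g_star**2 + b * g_star - np.min(a * grid**2 + b * grid)) < 1e-9
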
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 5da897d..390c32d 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -26,6 +26,8 @@ from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
+
+
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -220,7 +222,15 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
format
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ .. note:: This function will cast the computed transport plan to the data type
+ of the provided input with the following priority: :math:`\mathbf{a}`,
+ then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided.
+ Casting to an integer tensor might result in a loss of precision.
+ If this behaviour is unwanted, please make sure to provide a
+ floating point input.
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
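As a hedged illustration of the casting priority described in the note above (plain numpy, assuming a standard POT install): with floating point histograms the returned plan stays in floating point, whereas integer histograms would trigger the warning added further down and a cast of the plan:

    import numpy as np
    import ot

    a = np.array([0.5, 0.5])            # float64 source histogram
    b = np.array([0.5, 0.5])            # float64 target histogram
    M = np.array([[0.0, 1.0],
                  [1.0, 0.0]])          # ground cost
    G = ot.emd(a, b, M)                 # plan follows the dtype of a (then b, then M)
    print(G.dtype)                      # float64
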
@@ -287,12 +297,16 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
+ if len(a0) != 0:
+ type_as = a0
+ elif len(b0) != 0:
+ type_as = b0
+ else:
+ type_as = M0
nx = get_backend(M0, a0, b0)
# convert to numpy
- M = nx.to_numpy(M)
- a = nx.to_numpy(a)
- b = nx.to_numpy(b)
+ M, a, b = nx.to_numpy(M, a, b)
# ensure float64
a = np.asarray(a, dtype=np.float64)
@@ -327,15 +341,23 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
u, v = estimate_dual_null_weights(u, v, a, b, M)
result_code_string = check_result(result_code)
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
if log:
log = {}
log['cost'] = cost
- log['u'] = nx.from_numpy(u, type_as=a0)
- log['v'] = nx.from_numpy(v, type_as=b0)
+ log['u'] = nx.from_numpy(u, type_as=type_as)
+ log['v'] = nx.from_numpy(v, type_as=type_as)
log['warning'] = result_code_string
log['result_code'] = result_code
- return nx.from_numpy(G, type_as=M0), log
- return nx.from_numpy(G, type_as=M0)
+ return nx.from_numpy(G, type_as=type_as), log
+ return nx.from_numpy(G, type_as=type_as)
def emd2(a, b, M, processes=1,
@@ -358,7 +380,16 @@ def emd2(a, b, M, processes=1,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ .. note:: This function will cast the computed transport plan and
+ transportation loss to the data type of the provided input with the
+ following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`,
+ then :math:`\mathbf{M}` if marginals are not provided.
+ Casting to an integer tensor might result in a loss of precision.
+ If this behaviour is unwanted, please make sure to provide a
+ floating point input.
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
@@ -428,12 +459,16 @@ def emd2(a, b, M, processes=1,
a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
+ if len(a0) != 0:
+ type_as = a0
+ elif len(b0) != 0:
+ type_as = b0
+ else:
+ type_as = M0
nx = get_backend(M0, a0, b0)
# convert to numpy
- M = nx.to_numpy(M)
- a = nx.to_numpy(a)
- b = nx.to_numpy(b)
+ M, a, b = nx.to_numpy(M, a, b)
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
@@ -466,15 +501,24 @@ def emd2(a, b, M, processes=1,
result_code_string = check_result(result_code)
log = {}
- G = nx.from_numpy(G, type_as=M0)
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
+ G = nx.from_numpy(G, type_as=type_as)
if return_matrix:
log['G'] = G
- log['u'] = nx.from_numpy(u, type_as=a0)
- log['v'] = nx.from_numpy(v, type_as=b0)
+ log['u'] = nx.from_numpy(u, type_as=type_as)
+ log['v'] = nx.from_numpy(v, type_as=type_as)
log['warning'] = result_code_string
log['result_code'] = result_code
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0, b0, M0), (log['u'], log['v'], G))
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
+ (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+ log['v'] - nx.mean(log['v']), G))
return [cost, log]
else:
def f(b):
@@ -487,10 +531,18 @@ def emd2(a, b, M, processes=1,
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
- G = nx.from_numpy(G, type_as=M0)
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
- nx.from_numpy(v, type_as=b0), G))
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
+ G = nx.from_numpy(G, type_as=type_as)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
+ (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
+ nx.from_numpy(v - np.mean(v), type_as=type_as), G))
check_result(result_code)
return cost
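A hedged note on the mean-centering of the duals introduced in this hunk: EMD dual potentials are only defined up to an additive constant, and subtracting the mean selects the representative whose entries sum to zero, which is the natural gradient direction along the simplex of histograms. With the torch backend (assuming torch is installed) the effect can be checked directly:

    import torch
    import ot

    a = torch.tensor([0.3, 0.7], requires_grad=True)
    b = torch.tensor([0.5, 0.5])
    M = torch.tensor([[0.0, 1.0],
                      [1.0, 0.0]])
    loss = ot.emd2(a, b, M)      # exact OT cost, differentiable via the backend
    loss.backward()
    print(a.grad, a.grad.sum())  # gradient is the centered dual u, summing to ~0
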
@@ -535,18 +587,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Parameters
----------
- measures_locations : list of N (k_i,d) numpy.ndarray
+ measures_locations : list of N (k_i,d) array-like
The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
(:math:`k_i` can be different for each element of the list)
- measures_weights : list of N (k_i,) numpy.ndarray
+ measures_weights : list of N (k_i,) array-like
Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
representing the weights of each discrete input measure
- X_init : (k,d) np.ndarray
+ X_init : (k,d) array-like
Initialization of the support locations (on `k` atoms) of the barycenter
- b : (k,) np.ndarray
+ b : (k,) array-like
Initialization of the weights of the barycenter (non-negatives, sum to 1)
- weights : (N,) np.ndarray
+ weights : (N,) array-like
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
numItermax : int, optional
@@ -564,7 +616,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Returns
-------
- X : (k,d) np.ndarray
+ X : (k,d) array-like
Support locations (on k atoms) of the barycenter
@@ -577,15 +629,17 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
"""
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
+
iter_count = 0
N = len(measures_locations)
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
- b = np.ones((k,)) / k
+ b = nx.ones((k,), type_as=X_init) / k
if weights is None:
- weights = np.ones((N,)) / N
+ weights = nx.ones((N,), type_as=X_init) / N
X = X_init
@@ -596,15 +650,15 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
while (displacement_square_norm > stopThr and iter_count < numItermax):
- T_sum = np.zeros((k, d))
+ T_sum = nx.zeros((k, d), type_as=X_init)
+
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
- weights.tolist()):
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
- T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
+ T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
- displacement_square_norm = np.sum(np.square(T_sum - X))
+ displacement_square_norm = nx.sum((T_sum - X)**2)
if log:
displacement_square_norms.append(displacement_square_norm)
@@ -620,3 +674,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X
+
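This hunk makes the free support barycenter solver backend-generic (the weights are created with nx.ones and the fixed-point update uses nx.dot), so backend arrays such as torch tensors can be passed as well as numpy arrays. A minimal numpy usage sketch with random data, purely for illustration:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    measures_locations = [rng.randn(7, 2), rng.randn(5, 2)]  # two empirical measures in R^2
    measures_weights = [ot.unif(7), ot.unif(5)]              # uniform weights on each support
    X_init = rng.randn(4, 2)                                 # barycenter supported on 4 atoms

    X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
    print(X.shape)  # (4, 2)
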
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 869d450..fbf3c0e 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -11,7 +11,6 @@ import numpy as np
import scipy as sp
import scipy.sparse as sps
-
try:
import cvxopt
from cvxopt import solvers, matrix, spmatrix
diff --git a/ot/optim.py b/ot/optim.py
index f25e2c9..5a1d605 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -9,12 +9,19 @@ Generic solvers for regularized OT
# License: MIT License
import numpy as np
-from scipy.optimize.linesearch import scalar_search_armijo
+import warnings
from .lp import emd
from .bregman import sinkhorn
-from ot.utils import list_to_array
+from .utils import list_to_array
from .backend import get_backend
+with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ from scipy.optimize import scalar_search_armijo
+ except ImportError:
+ from scipy.optimize.linesearch import scalar_search_armijo
+
# The corresponding scipy function does not work for matrices
diff --git a/ot/partial.py b/ot/partial.py
index b7093e4..0a9e450 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -7,7 +7,6 @@ Partial OT solvers
# License: MIT License
import numpy as np
-
from .lp import emd
@@ -29,7 +28,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
\gamma &\geq 0
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &
+ \leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X.
@@ -50,7 +50,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
- :math:`\lambda` is the lagrangian cost. Tuning its value allows attaining
a given mass to be transported `m`
- The formulation of the problem has been proposed in :ref:`[28] <references-partial-wasserstein-lagrange>`
+ The formulation of the problem has been proposed in
+ :ref:`[28] <references-partial-wasserstein-lagrange>`
Parameters
@@ -261,7 +262,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
+ M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
M_extended[:len(a), :len(b)] = M
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
@@ -455,7 +456,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein>`
+ The formulation of the problem has been proposed in
+ :ref:`[29] <references-partial-gromov-wasserstein>`
Parameters
@@ -469,7 +471,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported
+ (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -623,16 +626,19 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
\gamma &\geq 0
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
+ \mathbf{1}^T \gamma^T \mathbf{1} = m
+ &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- :math:`\mathbf{M}` is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein2>`
+ The formulation of the problem has been proposed in
+ :ref:`[29] <references-partial-gromov-wasserstein2>`
Parameters
@@ -646,7 +652,8 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported
+ (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -728,21 +735,25 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
The function considers the following problem:
.. math::
- \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma,
+ \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma)
s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\
\gamma^T \mathbf{1} &\leq \mathbf{b} \\
\gamma &\geq 0 \\
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\
+ \mathbf{1}^T \gamma^T \mathbf{1} = m
+ &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\
where :
- :math:`\mathbf{M}` is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5)
+ The formulation of the problem has been proposed in
+ :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5)
Parameters
@@ -829,12 +840,23 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
np.multiply(K, m / np.sum(K), out=K)
err, cpt = 1, 0
+ q1 = np.ones(K.shape)
+ q2 = np.ones(K.shape)
+ q3 = np.ones(K.shape)
while (err > stopThr and cpt < numItermax):
Kprev = K
+ K = K * q1
K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
+ q1 = q1 * Kprev / K1
+ K1prev = K1
+ K1 = K1 * q2
K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
+ q2 = q2 * K1prev / K2
+ K2prev = K2
+ K2 = K2 * q3
K = K2 * (m / np.sum(K2))
+ q3 = q3 * K2prev / K
if np.any(np.isnan(K)) or np.any(np.isinf(K)):
print('Warning: numerical errors at iteration', cpt)
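The q1/q2/q3 factors added in this hunk appear to keep track of the multiplicative corrections of the three successive scaling steps (row cap by a, column cap by b, total mass m), in the spirit of Dykstra's corrections for cyclic projections, so that each step restarts from the corrected kernel. The public interface is unchanged; a hedged usage sketch on toy data:

    import numpy as np
    import ot

    a = np.array([0.6, 0.4])
    b = np.array([0.5, 0.5])
    M = np.array([[0.0, 1.0],
                  [1.0, 0.0]])
    # transport only 70% of the total mass, with entropic smoothing
    G = ot.partial.entropic_partial_wasserstein(a, b, M, reg=0.1, m=0.7)
    print(G.sum())  # ~0.7
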
@@ -861,7 +883,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ Returns the partial Gromov-Wasserstein transport between
+ :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
@@ -877,7 +900,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
\gamma^T \mathbf{1} &\leq \mathbf{b}
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
+ \mathbf{1}^T \gamma^T \mathbf{1} = m
+ &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
@@ -885,10 +909,13 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
- :math:`\mathbf{C_2}` is the metric cost matrix in the target space
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
- `L`: quadratic loss function
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein>`
+ The formulation of the GW problem has been proposed in
+ :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the
+ partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein>`
Parameters
----------
@@ -903,7 +930,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported (default:
+ :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
@@ -1005,13 +1033,15 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ Returns the partial Gromov-Wasserstein discrepancy between
+ :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot
- \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
+ GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k},
+ \mathbf{C_2}_{j,l})\cdot
+ \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
.. math::
s.t. \ \gamma &\geq 0
@@ -1028,10 +1058,13 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
- :math:`\mathbf{C_2}` is the metric cost matrix in the target space
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
- `L` : quadratic loss function
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>`
+ The formulation of the GW problem has been proposed in
+ :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the
+ partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>`
Parameters
@@ -1047,7 +1080,8 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported (default:
+ :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
diff --git a/ot/plot.py b/ot/plot.py
index 2208c90..8ade2eb 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
if ('color' not in kwargs) and ('c' not in kwargs):
kwargs['color'] = 'k'
mx = G.max()
+ if 'alpha' in kwargs:
+ scale = kwargs['alpha']
+ del kwargs['alpha']
+ else:
+ scale = 1
for i in range(xs.shape[0]):
for j in range(xt.shape[0]):
if G[i, j] / mx > thr:
pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
- alpha=G[i, j] / mx, **kwargs)
+ alpha=G[i, j] / mx * scale, **kwargs)
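A hedged example of the new behaviour: an alpha keyword passed to plot2D_samples_mat is no longer forwarded as-is to matplotlib (where it would clash with the per-link opacity) but acts as a global scaling factor on G[i, j] / G.max():

    import numpy as np
    import matplotlib.pyplot as plt
    import ot
    import ot.plot

    rng = np.random.RandomState(0)
    xs = rng.randn(20, 2)          # source samples
    xt = rng.randn(20, 2) + 4      # target samples
    G = ot.emd(ot.unif(20), ot.unif(20), ot.dist(xs, xt))

    ot.plot.plot2D_samples_mat(xs, xt, G, alpha=0.5)  # halve the opacity of every link
    plt.show()
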
diff --git a/ot/regpath.py b/ot/regpath.py
index 269937a..e745288 100644
--- a/ot/regpath.py
+++ b/ot/regpath.py
@@ -11,34 +11,48 @@ import scipy.sparse as sp
def recast_ot_as_lasso(a, b, C):
- r"""This function recasts the l2-penalized UOT problem as a Lasso problem
+ r"""This function recasts the l2-penalized UOT problem as a Lasso problem.
+
+ Recall the l2-penalized UOT problem defined in
+ :ref:`[41] <references-regpath>`
- Recall the l2-penalized UOT problem defined in [Chapel et al., 2021]
.. math::
- UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2 +
- \lambda \|T^T 1_n - b\|_2^2
+ \text{UOT}_{\lambda} = \min_T <C, T> + \lambda \|T 1_m -
+ \mathbf{a}\|_2^2 +
+ \lambda \|T^T 1_n - \mathbf{b}\|_2^2
+
s.t.
T \geq 0
+
where :
- - C is the (dim_a, dim_b) metric cost matrix
- - :math:`\lambda` is the l2-regularization coefficient
- - a and b are source and target distributions
- - T is the transport plan to optimize
- The problem above can be reformulated to a non-negative penalized
+ - :math:`C` is the cost matrix
+ - :math:`\lambda` is the l2-regularization parameter
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \
+ distributions
+ - :math:`T` is the transport plan to optimize
+
+ The problem above can be reformulated as a non-negative penalized
linear regression problem, particularly Lasso
+
.. math::
- UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ \text{UOT2}_{\lambda} = \min_{\mathbf{t}} \gamma \mathbf{c}^T
+ \mathbf{t} + 0.5 * \|H \mathbf{t} - \mathbf{y}\|_2^2
+
s.t.
- t \geq 0
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
- - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
- - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
- see [Chapel et al., 2021] for the design of H. The matrix product H t
- computes both the source marginal and the target marginal.
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix :math:`C`
+ - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \
+ and :math:`\mathbf{b}`
+ - :math:`H` is a metric matrix, see :ref:`[41] <references-regpath>` for \
+ the design of :math:`H`. The matrix product :math:`H\mathbf{t}` \
+ computes both the source marginal and the target marginals.
+ - :math:`\mathbf{t}` is the flattened version of the transport plan \
+ :math:`T`
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -47,14 +61,16 @@ def recast_ot_as_lasso(a, b, C):
Histogram of dimension dim_b
C : np.ndarray, shape (dim_a, dim_b)
Cost matrix
+
Returns
-------
H : np.ndarray (dim_a+dim_b, dim_a*dim_b)
- Auxiliary matrix constituted by 0 and 1
+ Design matrix that contains only 0 and 1
y : np.ndarray (ns + nt, )
- Concatenation of histogram a and histogram b
+ Concatenation of histograms :math:`\mathbf{a}` and :math:`\mathbf{b}`
c : np.ndarray (ns * nt, )
- Flattened array of cost matrix
+ Flattened array of the cost matrix
+
Examples
--------
>>> import ot
@@ -73,12 +89,12 @@ def recast_ot_as_lasso(a, b, C):
>>> c
array([16., 25., 28., 16., 40., 36.])
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
dim_a = np.shape(a)[0]
@@ -97,33 +113,47 @@ def recast_ot_as_lasso(a, b, C):
def recast_semi_relaxed_as_lasso(a, b, C):
- r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem
+ r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem.
.. math::
- semi-relaxed UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2
+
+ \text{semi-relaxed UOT} = \min_T <C, T>
+ + \lambda \|T 1_m - \mathbf{a}\|_2^2
+
s.t.
- T^T 1_n = b
- t \geq 0
+ T^T 1_n = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
where :
- - C is the (dim_a, dim_b) metric cost matrix
- - :math:`\lambda` is the l2-regularization coefficient
- - a and b are source and target distributions
- - T is the transport plan to optimize
+
+ - :math:`C` is the metric cost matrix
+ - :math:`\lambda` is the l2-regularization parameter
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \
+ distributions
+ - :math:`T` is the transport plan to optimize
The problem above can be reformulated as follows
+
.. math::
- semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+ \text{semi-relaxed UOT2} = \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2
+
s.t.
- H_c t = b
- t \geq 0
+ H_c \mathbf{t} = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
- - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - H_r is a (dim_a, dim_a * dim_b) metric matrix,
- which computes the sum along the rows of transport plan T
- - H_c is a (dim_b, dim_a * dim_b) metric matrix,
- which computes the sum along the columns of transport plan T
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix :math:`C`
+ - :math:`\gamma = 1/\lambda` is the l2-regularization parameter
+ - :math:`H_r` is a metric matrix which computes the sum along the \
+ rows of the transport plan :math:`T`
+ - :math:`H_c` is a metric matrix which computes the sum along the \
+ columns of the transport plan :math:`T`
+ - :math:`\mathbf{t}` is the flattened version of :math:`T`
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -132,16 +162,18 @@ def recast_semi_relaxed_as_lasso(a, b, C):
Histogram of dimension dim_b
C : np.ndarray, shape (dim_a, dim_b)
Cost matrix
+
Returns
-------
Hr : np.ndarray (dim_a, dim_a * dim_b)
Auxiliary matrix constituted by 0 and 1, which computes
- the sum along the rows of transport plan T
+ the sum along the rows of transport plan :math:`T`
Hc : np.ndarray (dim_b, dim_a * dim_b)
Auxiliary matrix constituted by 0 and 1, which computes
- the sum along the columns of transport plan T
+ the sum along the columns of transport plan :math:`T`
c : np.ndarray (ns * nt, )
- Flattened array of cost matrix
+ Flattened array of the cost matrix
+
Examples
--------
>>> import ot
@@ -179,49 +211,60 @@ def recast_semi_relaxed_as_lasso(a, b, C):
def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma):
r""" This function computes the next value of gamma if a variable
- will be added in next iteration of the regularization path
+ is added in the next iteration of the regularization path.
We look for the largest value of gamma such that
the gradient of an inactive variable vanishes
+
.. math::
- \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i}
+ \max_{i \in \bar{A}} \frac{\mathbf{h}_i^T(H_A \phi - \mathbf{y})}
+ {\mathbf{h}_i^T H_A \delta - \mathbf{c}_i}
+
where :
+
- A is the current active set
- - h_i is the ith column of auxiliary matrix H
- - H_A is the sub-matrix constructed by the columns of H
- whose indices belong to the active set A
- - c_i is the ith element of cost vector c
- - y is the concatenation of source and target distribution
- - :math:`\phi` is the intercept of the solutions in current iteration
- - :math:`\delta` is the slope of the solutions in current iteration
+ - :math:`\mathbf{h}_i` is the :math:`i` th column of the design \
+ matrix :math:`{H}`
+ - :math:`{H}_A` is the sub-matrix constructed by the columns of \
+ :math:`{H}` whose indices belong to the active set A
+ - :math:`\mathbf{c}_i` is the :math:`i` th element of the cost vector \
+ :math:`\mathbf{c}`
+ - :math:`\mathbf{y}` is the concatenation of the source and target \
+ distributions
+ - :math:`\phi` is the intercept of the solutions at the current iteration
+ - :math:`\delta` is the slope of the solutions at the current iteration
+
Parameters
----------
- phi : np.ndarray (|A|, )
- Intercept of the solutions in current iteration (t is piecewise linear)
- delta : np.ndarray (|A|, )
- Slope of the solutions in current iteration (t is piecewise linear)
+ phi : np.ndarray (size(A), )
+ Intercept of the solutions at the current iteration
+ delta : np.ndarray (size(A), )
+ Slope of the solutions at the current iteration
HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b)
- Matrix product of H^T H
+ Matrix product of :math:`{H}^T {H}`
Hty : np.ndarray (dim_a + dim_b, )
- Matrix product of H^T y
+ Matrix product of :math:`{H}^T \mathbf{y}`
c: np.ndarray (dim_a * dim_b, )
- Flattened array of cost matrix C
+ Flattened array of the cost matrix :math:`{C}`
active_index : list
Indices of active variables
current_gamma : float
- Value of regularization coefficient at the start of current iteration
+ Value of the regularization parameter at the beginning of the current \
+ iteration
+
Returns
-------
next_gamma : float
Value of gamma if a variable is added to active set in next iteration
next_active_index : int
Index of variable to be activated
+
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
M = (HtH[:, active_index].dot(phi) - Hty) / \
(HtH[:, active_index].dot(delta) - c + 1e-16)
@@ -237,56 +280,65 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra,
By taking the Lagrangian form of the problem, we obtain a similar update
as the two-sided relaxed UOT
+
.. math::
- \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T
- \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i} \delta_u - c_i}
+
+ \max_{i \in \bar{A}} \frac{\mathbf{h}_{ri}^T(H_{rA} \phi - \mathbf{a})
+ + \mathbf{h}_{c i}^T\phi_u}{\mathbf{h}_{r i}^T H_{r A} \delta + \
+ \mathbf{h}_{c i} \delta_u - \mathbf{c}_i}
+
where :
+
- A is the current active set
- - h_{r i} is the ith column of the matrix H_r
- - h_{c i} is the ith column of the matrix H_c
- - H_{r A} is the sub-matrix constructed by the columns of H_r
- whose indices belong to the active set A
- - c_i is the ith element of cost vector c
- - y is the concatenation of source and target distribution
+ - :math:`\mathbf{h}_{r i}` is the ith column of the matrix :math:`H_r`
+ - :math:`\mathbf{h}_{c i}` is the ith column of the matrix :math:`H_c`
+ - :math:`H_{r A}` is the sub-matrix constructed by the columns of \
+ :math:`H_r` whose indices belong to the active set A
+ - :math:`\mathbf{c}_i` is the :math:`i` th element of cost vector \
+ :math:`\mathbf{c}`
- :math:`\phi` is the intercept of the solutions in current iteration
- :math:`\delta` is the slope of the solutions in current iteration
- - :math:`\phi_u` is the intercept of Lagrange parameter in current
- iteration
- - :math:`\delta_u` is the slope of Lagrange parameter in current iteration
+ - :math:`\phi_u` is the intercept of Lagrange parameter at the \
+ current iteration
+ - :math:`\delta_u` is the slope of Lagrange parameter at the \
+ current iteration
+
Parameters
----------
- phi : np.ndarray (|A|, )
- Intercept of the solutions in current iteration (t is piecewise linear)
- delta : np.ndarray (|A|, )
- Slope of the solutions in current iteration (t is piecewise linear)
+ phi : np.ndarray (size(A), )
+ Intercept of the solutions at the current iteration
+ delta : np.ndarray (size(A), )
+ Slope of the solutions at the current iteration
phi_u : np.ndarray (dim_b, )
- Intercept of the Lagrange parameter in current iteration (also linear)
+ Intercept of the Lagrange parameter at the current iteration
delta_u : np.ndarray (dim_b, )
- Slope of the Lagrange parameter in current iteration (also linear)
+ Slope of the Lagrange parameter at the current iteration
HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
- Matrix product of H_r^T H_r
+ Matrix product of :math:`H_r^T H_r`
Hc : np.ndarray (dim_b, dim_a * dim_b)
- Matrix that computes the sum along the columns of transport plan T
+ Matrix that computes the sum along the columns of the transport plan \
+ :math:`T`
Hra : np.ndarray (dim_a * dim_b, )
- Matrix product of H_r^T a
+ Matrix product of :math:`H_r^T \mathbf{a}`
c: np.ndarray (dim_a * dim_b, )
- Flattened array of cost matrix C
+ Flattened array of cost matrix :math:`C`
active_index : list
Indices of active variables
current_gamma : float
Value of regularization coefficient at the start of current iteration
+
Returns
-------
next_gamma : float
Value of gamma if a variable is added to active set in next iteration
next_active_index : int
Index of variable to be activated
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \
@@ -297,37 +349,48 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra,
def compute_next_removal(phi, delta, current_gamma):
- r""" This function computes the next value of gamma if a variable
- is removed in next iteration of regularization path
+ r""" This function computes the next gamma value if a variable
+ is removed at the next iteration of the regularization path.
+
+ We look for the largest value of the regularization parameter such that
+ an element of the current solution vanishes
- We look for the largest value of gamma such that
- an element of current solution vanishes
.. math::
\max_{j \in A} \frac{\phi_j}{\delta_j}
+
where :
+
- A is the current active set
- - phi_j is the jth element of the intercept of current solution
- - delta_j is the jth elemnt of the slope of current solution
+ - :math:`\phi_j` is the :math:`j` th element of the intercept of the \
+ current solution
+ - :math:`\delta_j` is the :math:`j` th element of the slope of the \
+ current solution
+
+
Parameters
----------
- phi : np.ndarray (|A|, )
- Intercept of the solutions in current iteration (t is piecewise linear)
- delta : np.ndarray (|A|, )
- Slope of the solutions in current iteration (t is piecewise linear)
+ phi : ndarray, shape (size(A), )
+ Intercept of the solution at the current iteration
+ delta : ndarray, shape (size(A), )
+ Slope of the solution at the current iteration
current_gamma : float
- Value of regularization coefficient at the start of current iteration
+ Value of the regularization parameter at the beginning of the \
+ current iteration
+
Returns
-------
next_removal_gamma : float
- Value of gamma if a variable is removed in next iteration
+ Gamma value if a variable is removed at the next iteration
next_removal_index : int
- Index of the variable to remove in next iteration
+ Index of the variable to be removed at the next iteration
+
+
+ .. _references-regpath:
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
r_candidate = phi / (delta - 1e-16)
r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0
@@ -335,56 +398,74 @@ def compute_next_removal(phi, delta, current_gamma):
def complement_schur(M_current, b, d, id_pop):
- r""" This function computes the inverse of matrix in regularization path
- using Schur complement
+ r""" This function computes the inverse of the design matrix in the \
+ regularization path using the Schur complement. Two cases may arise:
+
+ Case 1: one variable is added to the active set
+
- Two cases may arise: Firstly one variable is added to the active set
.. math::
M_{k+1}^{-1} =
\begin{bmatrix}
- M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\
- - s^{-1} b^T M_{k}^{-1} & s^{-1}
+ M_{k}^{-1} + s^{-1} M_{k}^{-1} \mathbf{b} \mathbf{b}^T M_{k}^{-1} \
+ & - M_{k}^{-1} \mathbf{b} s^{-1} \\
+ - s^{-1} \mathbf{b}^T M_{k}^{-1} & s^{-1}
\end{bmatrix}
+
+
where :
- - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and
- :math:`M_k` is the upper left block matrix in Schur formulation
- - b is the upper right block matrix in Schur formulation. In our case,
- b is reduced to a column vector and b^T is the lower left block matrix
- - s is the Schur complement, given by
- :math:`s = d - b^T M_{k}^{-1} b` in our case
-
- Secondly, one variable is removed from the active set
+
+ - :math:`M_k^{-1}` is the inverse of the design matrix :math:`H_A^tH_A` \
+ of the previous iteration
+ - :math:`\mathbf{b}` is the last column of :math:`M_{k}`
+ - :math:`s` is the Schur complement, given by \
+ :math:`s = \mathbf{d} - \mathbf{b}^T M_{k}^{-1} \mathbf{b}`
+
+ Case 2: one variable is removed from the active set.
+
.. math::
- M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} -
+ M_{k+1}^{-1} = M^{-1}_{k \backslash q} -
\frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}}
+
where :
- - q is the index of column and row to delete
- - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix
- without qth column and qth row
- - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element
- - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}`
+
+ - :math:`q` is the index of column and row to delete
+ - :math:`M^{-1}_{k \backslash q}` is the previous inverse matrix deprived \
+ of the :math:`q` th column and :math:`q` th row
+ - :math:`r_{-q,q}` is the :math:`q` th column of :math:`M^{-1}_{k}` \
+ without the :math:`q` th element
+ - :math:`r_{q, q}` is the element of :math:`q` th column and :math:`q` th \
+ row in :math:`M^{-1}_{k}`
+
+
Parameters
----------
- M_current : np.ndarray (|A|-1, |A|-1)
- Inverse matrix in previous iteration
- b : np.ndarray (|A|-1, )
- Upper right matrix in Schur complement, a column vector in our case
+ M_current : ndarray, shape (size(A)-1, size(A)-1)
+ Inverse matrix of :math:`H_A^tH_A` at the previous iteration, with \
+ size(A) the size of the active set
+ b : ndarray, shape (size(A)-1, )
+ None for case 2 (removal), last column of :math:`M_{k}` for case 1 \
+ (addition)
d : float
- Lower right matrix in Schur complement, a scalar in our case
- id_pop
+ Should be equal to 2 for the UOT problem and to 1 for the semi-relaxed OT
+ id_pop : int
Index of the variable to be removed, equal to -1
- if none of the variables is deleted in current iteration
+ if no variable is deleted at the current iteration
+
+
Returns
-------
- M : np.ndarray (|A|, |A|)
- Inverse matrix needed in current iteration
+ M : ndarray, shape (size(A), size(A))
+ Inverse matrix of :math:`H_A^tH_A` of the current iteration
+
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
+
if b is None:
b = M_current[id_pop, :]
b = np.delete(b, id_pop)
@@ -409,33 +490,39 @@ def complement_schur(M_current, b, d, id_pop):
def construct_augmented_H(active_index, m, Hc, HrHr):
- r""" This function construct an augmented matrix for the first iteration of
- semi-relaxed regularization path
+ r""" This function constructs an augmented matrix for the first iteration
+ of the semi-relaxed regularization path
.. math::
- Augmented_H =
+ \text{Augmented}_H =
\begin{bmatrix}
0 & H_{c A} \\
H_{c A}^T & H_{r A}^T H_{r A}
\end{bmatrix}
+
where :
- - H_{r A} is the sub-matrix constructed by the columns of H_r
- whose indices belong to the active set A
- - H_{c A} is the sub-matrix constructed by the columns of H_c
- whose indices belong to the active set A
+
+ - :math:`H_{r A}` is the sub-matrix constructed by the columns of \
+ :math:`H_r` whose indices belong to the active set A
+ - :math:`H_{c A}` is the sub-matrix constructed by the columns of \
+ :math:`H_c` whose indices belong to the active set A
+
+
Parameters
----------
active_index : list
- Indices of active variables
+ Indices of the active variables
m : int
Length of the target distribution
Hc : np.ndarray (dim_b, dim_a * dim_b)
- Matrix that computes the sum along the columns of transport plan T
+ Matrix that computes the sum along the columns of the transport plan \
+ :math:`T`
HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
- Matrix product of H_r^T H_r
+ Matrix product of :math:`H_r^T H_r`
+
Returns
-------
- H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|)
+ H_augmented : np.ndarray (dim_b + size(A), dim_b + size(A))
Augmented matrix for the first iteration of the semi-relaxed
regularization path
"""
@@ -451,18 +538,27 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
r"""This function gives the regularization path of l2-penalized UOT problem
The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+
.. math::
- \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2
+
s.t.
- t \geq 0
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix \
+ :math:`{C}`
- :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
- - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
- see [Chapel et al., 2021] for the design of H. The matrix product Ht
- computes both the source marginal and the target marginal.
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \
+ and :math:`\mathbf{b}`, defined as \
+ :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]`
+ - :math:`{H}` is a design matrix, see :ref:`[41] <references-regpath>` \
+ for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \
+ computes both the source marginal and the target marginals.
+ - :math:`\mathbf{t}` is the flattened version of the transport matrix
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -478,11 +574,12 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Flattened vector of optimal transport matrix
+ Flattened vector of the optimal transport matrix
t_list : list
- List of solutions in regularization path
+ List of solutions in the regularization path
gamma_list : list
- List of regularization coefficient in regularization path
+ List of regularization coefficients in the regularization path
+
Examples
--------
>>> import ot
@@ -502,10 +599,9 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
n = np.shape(a)[0]
@@ -580,22 +676,32 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
itmax=50000):
r"""This function gives the regularization path of semi-relaxed
- l2-UOT problem
+ l2-UOT problem.
The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+
.. math::
- \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2
+
s.t.
- H_c t = b
- t \geq 0
+ H_c \mathbf{t} = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
- - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - H_r is a (dim_a, dim_a * dim_b) metric matrix,
- which computes the sum along the rows of transport plan T
- - H_c is a (dim_b, dim_a * dim_b) metric matrix,
- which computes the sum along the columns of transport plan T
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix \
+ :math:`C`
+ - :math:`\gamma = 1/\lambda` is the l2-regularization parameter
+ - :math:`H_r` is a matrix that computes the sum along the rows of \
+ the transport plan :math:`T`
+ - :math:`H_c` is a matrix that computes the sum along the columns of \
+ the transport plan :math:`T`
+ - :math:`\mathbf{t}` is the flattened version of the transport plan \
+ :math:`T`
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -608,14 +714,16 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
l2-regularization coefficient
itmax: int (optional)
Maximum number of iteration
+
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Flattened vector of optimal transport matrix
+ Flattened vector of the (unregularized) optimal transport matrix
t_list : list
- List of solutions in regularization path
+ List of all the optimal transport vectors of the regularization path
gamma_list : list
- List of regularization coefficient in regularization path
+ List of the regularization parameters in the path
+
Examples
--------
>>> import ot
@@ -635,10 +743,9 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
n = np.shape(a)[0]
@@ -722,8 +829,44 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
semi_relaxed=False, itmax=50000):
- r"""This function combines both the semi-relaxed and the fully-relaxed
- regularization paths of l2-UOT problem
+ r"""This function provides all the solutions of the regularization path \
+ of the l2-UOT problem :ref:`[41] <references-regpath>`.
+
+ The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+
+ .. math::
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2
+
+ s.t.
+ \mathbf{t} \geq 0
+
+ where :
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix \
+ :math:`{C}`
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \
+ and :math:`\mathbf{b}`, defined as \
+ :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]`
+ - :math:`{H}` is a design matrix, see :ref:`[41] <references-regpath>` \
+ for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \
+ computes both the source marginal and the target marginals.
+ - :math:`\mathbf{t}` is the flattened version of the transport matrix
+
+ For the semi-relaxed problem, it optimizes the Lasso reformulation of the
+ l2-penalized UOT:
+
+ .. math::
+
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2
+
+ s.t.
+ H_c \mathbf{t} = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
Parameters
----------
@@ -736,23 +879,24 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
reg: float (optional)
l2-regularization coefficient
semi_relaxed : bool (optional)
- Give the semi-relaxed path if true
+ Give the semi-relaxed path if True
itmax: int (optional)
Maximum number of iteration
+
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Flattened vector of optimal transport matrix
+ Flattened vector of the (unregularized) optimal transport matrix
t_list : list
- List of solutions in regularization path
+ List of all the optimal transport vectors of the regularization path
gamma_list : list
- List of regularization coefficient in regularization path
+ List of the regularization parameters in the path
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
if semi_relaxed:
t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg,
@@ -765,27 +909,33 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
def compute_transport_plan(gamma, gamma_list, Pi_list):
r""" Given the regularization path, this function computes the transport
- plan for any value of gamma by the piecewise linearity of the path
+ plan for any value of gamma thanks to the piecewise linearity of the path.
.. math::
t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma)
- where :
- - :math:`\gamma` is the regularization coefficient
+
+ where:
+
+ - :math:`\gamma` is the regularization parameter
- :math:`\phi(\gamma)` is the corresponding intercept
- :math:`\delta(\gamma)` is the corresponding slope
- - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix)
+ - :math:`\mathbf{t}` is the flattened version of the transport matrix
+
Parameters
----------
gamma : float
Regularization coefficient
gamma_list : list
- List of regularization coefficients in regularization path
+ List of regularization parameters of the regularization path
Pi_list : list
- List of solutions in regularization path
+ List of all the solutions of the regularization path
+
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Transport vector corresponding to the given value of gamma
+ Vectorization of the transport plan corresponding to the given value
+ of gamma
+
Examples
--------
>>> import ot
@@ -804,12 +954,13 @@ def compute_transport_plan(gamma, gamma_list, Pi_list):
array([0. , 0. , 0. , 0.19722222, 0.05555556,
0. , 0. , 0.24722222, 0. ])
+
+ .. _references-regpath:
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
if gamma >= gamma_list[0]:
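Below is a minimal usage sketch (not part of the patch) of the two path functions documented above, on a small random problem; the sizes and the choice of gamma are purely illustrative. compute_transport_plan exploits the piecewise linearity t(gamma) = phi(gamma) - gamma * delta(gamma) to interpolate the plan between the breakpoints returned by regularization_path.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(5, 2)
    xt = rng.randn(6, 2) + 1.0
    a, b = ot.unif(5), ot.unif(6)
    C = ot.dist(xs, xt)
    C /= C.max()

    # full l2-UOT path, down to gamma = reg
    # (use semi_relaxed=True for the semi-relaxed formulation)
    t, t_list, gamma_list = ot.regpath.regularization_path(a, b, C, reg=1e-4)

    # interpolate the flattened plan at a gamma between two breakpoints
    gamma = 0.5 * (gamma_list[0] + gamma_list[-1])
    t_gamma = ot.regpath.compute_transport_plan(gamma, gamma_list, t_list)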
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 693675f..61be9bb 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -4,12 +4,14 @@ Stochastic solvers for regularized OT.
"""
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
import numpy as np
-
+from .utils import dist
+from .backend import get_backend
##############################################################################
# Optimization toolbox for SEMI - DUAL problems
@@ -747,3 +749,239 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
return pi, log
else:
return pi
+
+
+################################################################################
+# Losses for stochastic optimization
+################################################################################
+
+def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the dual loss of the entropic OT as in equations (6)-(7) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -reg * nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the primal OT plan of the entropic OT as in equation (8) of [19]
+
+ This function is backend compatible and can be used to recover the primal
+ plan from the dual potentials. It can be used on the full dataset (beware
+ of memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ G : array-like
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return ws[:, None] * H * wt[None, :]
+
+
+def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the dual loss of the quadratic regularized OT as in equations (6)-(7) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)**2
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
+
+
+def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the primal OT plan of the quadratic regularized OT as in equation (8) of [19]
+
+ This function is backend compatible and can be used to recover the primal
+ plan from the dual potentials. It can be used on the full dataset (beware
+ of memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ G : array-like
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = 1.0 / (2 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)
+
+ return ws[:, None] * H * wt[None, :]
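A minimal sketch (not in the patch) of how the new dual losses are meant to be used, assuming a working PyTorch install; the batch size, learning rate and number of iterations are illustrative. The potentials are optimized by stochastic gradient ascent on minibatches, then the primal plan is recovered; loss_dual_quadratic / plan_dual_quadratic can be swapped in the same way.

    import torch
    import ot

    torch.manual_seed(0)
    ns, nt, reg = 100, 80, 1.0
    xs = torch.randn(ns, 2)
    xt = torch.randn(nt, 2) + 1.0

    u = torch.zeros(ns, requires_grad=True)
    v = torch.zeros(nt, requires_grad=True)
    optimizer = torch.optim.Adam([u, v], lr=0.1)

    for it in range(500):
        # draw a minibatch of source and target samples
        idx_s = torch.randint(0, ns, (32,))
        idx_t = torch.randint(0, nt, (32,))
        # maximize the dual loss, i.e. minimize its opposite
        loss = -ot.stochastic.loss_dual_entropic(
            u[idx_s], v[idx_t], xs[idx_s], xt[idx_t], reg=reg)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # recover an (approximate) primal plan from the learned potentials
    G = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)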
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 15e180b..90c920c 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -4,13 +4,14 @@ Regularized Unbalanced OT solvers
"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
# License: MIT License
from __future__ import division
import warnings
-import numpy as np
-from scipy.special import logsumexp
+from .backend import get_backend
+from .utils import list_to_array
# from .utils import unif, dist
@@ -43,12 +44,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -70,12 +71,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -172,12 +173,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -198,7 +199,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Returns
-------
- ot_distance : (n_hists,) ndarray
+ ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
log : dict
log dictionary returned only if `log` is `True`
@@ -239,9 +240,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
"""
- b = np.asarray(b, dtype=np.float64)
+ b = list_to_array(b)
if len(b.shape) < 2:
b = b[:, None]
+
if method.lower() == 'sinkhorn':
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
numItermax=numItermax,
@@ -291,12 +293,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`
If many, compute all the OT distances (a, b_i)
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -315,12 +317,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -354,17 +356,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
ot.optim.cg : General regularized OT
"""
-
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=M) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=M) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -377,17 +377,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, 1)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, 1), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
a = a.reshape(dim_a, 1)
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(M / (-reg))
fi = reg_m / (reg_m + reg)
@@ -397,14 +394,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
uprev = u
vprev = v
- Kv = K.dot(v)
+ Kv = nx.dot(K, v)
u = (a / Kv) ** fi
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
v = (b / Ktu) ** fi
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % i)
@@ -412,8 +409,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
v = vprev
break
- err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err_u = nx.max(nx.abs(u - uprev)) / max(
+ nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
+ )
+ err_v = nx.max(nx.abs(v - vprev)) / max(
+ nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.
+ )
err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
@@ -426,11 +427,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
break
if log:
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
@@ -475,12 +476,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -501,12 +502,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -538,17 +539,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
ot.optim.cg : General regularized OT
"""
-
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+ nx = get_backend(M, a, b)
dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=M) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=M) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -561,56 +560,52 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
a = a.reshape(dim_a, 1)
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
# print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
- alpha = np.zeros(dim_a)
- beta = np.zeros(dim_b)
+ alpha = nx.zeros(dim_a, type_as=M)
+ beta = nx.zeros(dim_b, type_as=M)
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
- Kv = K.dot(v)
- f_alpha = np.exp(- alpha / (reg + reg_m))
- f_beta = np.exp(- beta / (reg + reg_m))
+ Kv = nx.dot(K, v)
+ f_alpha = nx.exp(- alpha / (reg + reg_m))
+ f_beta = nx.exp(- beta / (reg + reg_m))
if n_hists:
f_alpha = f_alpha[:, None]
f_beta = f_beta[:, None]
u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
if n_hists:
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
+ alpha = alpha + reg * nx.log(nx.max(u, 1))
+ beta = beta + reg * nx.log(nx.max(v, 1))
else:
- alpha = alpha + reg * np.log(np.max(u))
- beta = beta + reg * np.log(np.max(v))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
-
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha = alpha + reg * nx.log(nx.max(u))
+ beta = beta + reg * nx.log(nx.max(v))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(v.shape, type_as=v)
+ Kv = nx.dot(K, v)
+
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
@@ -620,8 +615,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
if (cpt % 10 == 0 and not absorbing) or cpt == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
- 1.)
+ err = nx.max(nx.abs(u - uprev)) / max(
+ nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
+ )
if log:
log['err'].append(err)
if verbose:
@@ -636,25 +632,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
"Try a larger entropy `reg` or a lower mass `reg_m`." +
"Or a larger absorption threshold `tau`.")
if n_hists:
- logu = alpha[:, None] / reg + np.log(u)
- logv = beta[:, None] / reg + np.log(v)
+ logu = alpha[:, None] / reg + nx.log(u)
+ logv = beta[:, None] / reg + nx.log(v)
else:
- logu = alpha / reg + np.log(u)
- logv = beta / reg + np.log(v)
+ logu = alpha / reg + nx.log(u)
+ logv = beta / reg + nx.log(v)
if log:
log['logu'] = logu
log['logv'] = logv
if n_hists: # return only loss
- res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
- logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
- res = np.exp(res)
+ res = nx.logsumexp(
+ nx.log(M + 1e-100)[:, :, None]
+ + logu[:, None, :]
+ + logv[None, :, :]
+ - M[:, :, None] / reg,
+ axis=(0, 1)
+ )
+ res = nx.exp(res)
if log:
return res, log
else:
return res
else: # return OT matrix
- ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg)
if log:
return ot_matrix, log
else:
@@ -683,9 +684,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
@@ -693,7 +694,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Marginal relaxation term > 0
tau : float
Stabilization threshold for log domain absorption.
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coordinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -708,7 +709,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -726,9 +727,12 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
"""
+ A, M = list_to_array(A, M)
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones(n_hists, type_as=A) / n_hists
else:
assert(len(weights) == A.shape[1])
@@ -737,47 +741,43 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
fi = reg_m / (reg_m + reg)
- u = np.ones((dim, n_hists)) / dim
- v = np.ones((dim, n_hists)) / dim
+ u = nx.ones((dim, n_hists), type_as=A) / dim
+ v = nx.ones((dim, n_hists), type_as=A) / dim
# print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
- alpha = np.zeros(dim)
- beta = np.zeros(dim)
- q = np.ones(dim) / dim
+ alpha = nx.zeros(dim, type_as=A)
+ beta = nx.zeros(dim, type_as=A)
+ q = nx.ones(dim, type_as=A) / dim
for i in range(numItermax):
- qprev = q.copy()
- Kv = K.dot(v)
- f_alpha = np.exp(- alpha / (reg + reg_m))
- f_beta = np.exp(- beta / (reg + reg_m))
+ qprev = nx.copy(q)
+ Kv = nx.dot(K, v)
+ f_alpha = nx.exp(- alpha / (reg + reg_m))
+ f_beta = nx.exp(- beta / (reg + reg_m))
f_alpha = f_alpha[:, None]
f_beta = f_beta[:, None]
u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
q = (Ktu ** (1 - fi)) * f_beta
- q = q.dot(weights) ** (1 / (1 - fi))
+ q = nx.dot(q, weights) ** (1 / (1 - fi))
Q = q[:, None]
v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha = alpha + reg * nx.log(nx.max(u, 1))
+ beta = beta + reg * nx.log(nx.max(v, 1))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(v.shape, type_as=v)
+ Kv = nx.dot(K, v)
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
@@ -786,8 +786,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
if (i % 10 == 0 and not absorbing) or i == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(q - qprev).max() / max(abs(q).max(),
- abs(qprev).max(), 1.)
+ err = nx.max(nx.abs(q - qprev)) / max(
+ nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.
+ )
if log:
log['err'].append(err)
if verbose:
@@ -804,8 +805,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
"Or a larger absorption threshold `tau`.")
if log:
log['niter'] = i
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
return q, log
else:
return q
@@ -833,15 +834,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
reg_m: float
Marginal relaxation term > 0
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coordinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -856,7 +857,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -874,40 +875,43 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
"""
+ A, M = list_to_array(A, M)
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones(n_hists, type_as=A) / n_hists
else:
assert(len(weights) == A.shape[1])
if log:
log = {'err': []}
- K = np.exp(- M / reg)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
- v = np.ones((dim, n_hists))
- u = np.ones((dim, 1))
- q = np.ones(dim)
+ v = nx.ones((dim, n_hists), type_as=A)
+ u = nx.ones((dim, 1), type_as=A)
+ q = nx.ones(dim, type_as=A)
err = 1.
for i in range(numItermax):
- uprev = u.copy()
- vprev = v.copy()
- qprev = q.copy()
+ uprev = nx.copy(u)
+ vprev = nx.copy(v)
+ qprev = nx.copy(q)
- Kv = K.dot(v)
+ Kv = nx.dot(K, v)
u = (A / Kv) ** fi
- Ktu = K.T.dot(u)
- q = ((Ktu ** (1 - fi)).dot(weights))
+ Ktu = nx.dot(K.T, u)
+ q = nx.dot(Ktu ** (1 - fi), weights)
q = q ** (1 / (1 - fi))
Q = q[:, None]
v = (Q / Ktu) ** fi
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % i)
@@ -916,8 +920,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
q = qprev
break
# compute change in barycenter
- err = abs(q - qprev).max()
- err /= max(abs(q).max(), abs(qprev).max(), 1.)
+ err = nx.max(nx.abs(q - qprev)) / max(
+ nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0
+ )
if log:
log['err'].append(err)
# if barycenter did not change + at least 10 iterations - stop
@@ -932,8 +937,8 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
if log:
log['niter'] = i
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
return q, log
else:
return q
@@ -961,15 +966,15 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
reg_m: float
Marginal relaxation term > 0
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coordinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -984,7 +989,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1025,3 +1030,225 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
log=log, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
+
+
+def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT plan.
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+ The algorithm used for solving the problem is a maximization-
+ minimization algorithm as proposed in :ref:`[41] <references-regpath>`
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg_m: float
+ Marginal relaxation term > 0
+ div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ gamma : (dim_a, dim_b) array-like
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'kl'), 2)
+ array([[0.3 , 0. ],
+ [0. , 0.07]])
+ >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'l2'), 2)
+ array([[0.25, 0. ],
+ [0. , 0. ]])
+
+
+ .. _references-regpath:
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT
+ """
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = nx.ones(dim_a, type_as=M) / dim_a
+ if len(b) == 0:
+ b = nx.ones(dim_b, type_as=M) / dim_b
+
+ if G0 is None:
+ G = a[:, None] * b[None, :]
+ else:
+ G = G0
+
+ if log:
+ log = {'err': [], 'G': []}
+
+ if div == 'kl':
+ K = nx.exp(M / - reg_m / 2)
+ elif div == 'l2':
+ K = nx.maximum(a[:, None] + b[None, :] - M / reg_m / 2,
+ nx.zeros((dim_a, dim_b), type_as=M))
+ else:
+ warnings.warn("The div parameter should be either equal to 'kl' or \
+ 'l2': it has been set to 'kl'.")
+ div = 'kl'
+ K = nx.exp(M / - reg_m / 2)
+
+ for i in range(numItermax):
+ Gprev = G
+
+ if div == 'kl':
+ u = nx.sqrt(a / (nx.sum(G, 1) + 1e-16))
+ v = nx.sqrt(b / (nx.sum(G, 0) + 1e-16))
+ G = G * K * u[:, None] * v[None, :]
+ elif div == 'l2':
+ Gd = nx.sum(G, 0, keepdims=True) + nx.sum(G, 1, keepdims=True) + 1e-16
+ G = G * K / Gd
+
+ err = nx.sqrt(nx.sum((G - Gprev) ** 2))
+ if log:
+ log['err'].append(err)
+ log['G'].append(G)
+ if verbose:
+ print('{:5d}|{:8e}|'.format(i, err))
+ if err < stopThr:
+ break
+
+ if log:
+ log['cost'] = nx.sum(G * M)
+ return G, log
+ else:
+ return G
+
+
+def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT cost.
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+ The algorithm used for solving the problem is a maximization-
+ minimization algorithm as proposed in :ref:`[41] <references-regpath>`
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg_m: float
+ Marginal relaxation term > 0
+ div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ ot_distance : array-like
+ the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}`
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2)
+ 0.25
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2)
+ 0.57
+
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
+
+ See Also
+ --------
+ ot.lp.emd2 : Unregularized OT loss
+ ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+ """
+ _, log_mm = mm_unbalanced(a, b, M, reg_m, div=div, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ verbose=verbose, log=True)
+
+ if log:
+ return log_mm['cost'], log_mm
+ else:
+ return log_mm['cost']
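A minimal sketch (not from the patch) of the new MM solvers on a tiny unbalanced problem; the numbers are illustrative. The plan returned by mm_unbalanced only approximately matches the input histograms, with reg_m controlling how strongly the marginals are enforced, and mm_unbalanced2 returns the corresponding cost.

    import numpy as np
    import ot

    a = np.array([1.0, 2.0])           # unnormalized source histogram
    b = np.array([2.0, 1.0, 0.5])      # unnormalized target histogram
    M = np.array([[0.0, 1.0, 4.0],
                  [1.0, 0.0, 1.0]])

    G = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=5.0, div='kl')
    print(G.sum(axis=1), a)            # relaxed source marginal vs a
    print(G.sum(axis=0), b)            # relaxed target marginal vs b

    cost = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=5.0, div='kl')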
diff --git a/ot/utils.py b/ot/utils.py
index e6c93c8..a23ce7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend
+from .backend import get_backend, Backend
__time_tic_toc = time.time()
@@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
def laplacian(x):
r"""Compute Laplacian matrix"""
- L = np.diag(np.sum(x, axis=0)) - x
+ nx = get_backend(x)
+ L = nx.diag(nx.sum(x, axis=0)) - x
return L
@@ -116,7 +117,7 @@ def proj_simplex(v, z=1):
return w
-def unif(n):
+def unif(n, type_as=None):
r"""
Return a uniform histogram of length `n` (simplex).
@@ -124,13 +125,19 @@ def unif(n):
----------
n : int
number of bins in the histogram
+ type_as : array_like
+ array of the same type as the expected output (numpy/pytorch/jax)
Returns
-------
- h : np.array (`n`,)
+ h : array_like (`n`,)
histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
"""
- return np.ones((n,)) / n
+ if type_as is None:
+ return np.ones((n,)) / n
+ else:
+ nx = get_backend(type_as)
+ return nx.ones((n,), type_as=type_as) / n
def clean_zeros(a, b, M):
@@ -290,7 +297,8 @@ def cost_normalization(C, norm=None):
def dots(*args):
r""" dots function for multiple matrix multiply """
- return reduce(np.dot, args)
+ nx = get_backend(*args)
+ return reduce(nx.dot, args)
def label_normalization(y, start=0):
@@ -308,8 +316,9 @@ def label_normalization(y, start=0):
y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
"""
+ nx = get_backend(y)
- diff = np.min(np.unique(y)) - start
+ diff = nx.min(nx.unique(y)) - start
if diff != 0:
y -= diff
return y
@@ -476,6 +485,19 @@ class BaseEstimator(object):
arguments (no ``*args`` or ``**kwargs``).
"""
+ nx: Backend = None
+
+ def _get_backend(self, *arrays):
+ nx = get_backend(
+ *[input_ for input_ in arrays if input_ is not None]
+ )
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError(
+ """JAX or TF arrays have been received but domain
+ adaptation does not support those backends.""")
+ self.nx = nx
+ return nx
+
@classmethod
def _get_param_names(cls):
r"""Get parameter names for the estimator"""
diff --git a/ot/weak.py b/ot/weak.py
new file mode 100644
index 0000000..f7d5b23
--- /dev/null
+++ b/ot/weak.py
@@ -0,0 +1,124 @@
+"""
+Weak optimal transport solvers
+"""
+
+# Author: Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .optim import cg
+import numpy as np
+
+__all__ = ['weak_optimal_transport']
+
+
+def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
+ r"""Solves the weak optimal transport problem between two empirical distributions
+
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \|X_a - \mathrm{diag}(1/\mathbf{a}) \gamma X_b\|_F^2
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
+
+ where :
+
+ - :math:`X_a` and :math:`X_b` are the sample matrices.
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ Uses the conditional gradient algorithm to solve the problem proposed
+ in :ref:`[39] <references-weak>`.
+
+ Parameters
+ ----------
+ Xa : (ns,d) array-like, float
+ Source samples
+ Xb : (nt,d) array-like, float
+ Target samples
+ a : (ns,) array-like, float
+ Source histogram (uniform weights if None)
+ b : (nt,) array-like, float
+ Target histogram (uniform weights if None)
+ G0 : (ns,nt) array-like, float, optional
+ Initial transport plan (default is the product coupling of the weights)
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
+
+
+ .. _references-weak:
+ References
+ ----------
+ .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
+ Kantorovich duality for general transport costs and applications.
+ Journal of Functional Analysis, 273(11), 3327-3405.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+ """
+
+ nx = get_backend(Xa, Xb)
+
+ Xa2 = nx.to_numpy(Xa)
+ Xb2 = nx.to_numpy(Xb)
+
+ if a is None:
+ a2 = np.ones((Xa.shape[0])) / Xa.shape[0]
+ else:
+ a2 = nx.to_numpy(a)
+ if b is None:
+ b2 = np.ones((Xb.shape[0])) / Xb.shape[0]
+ else:
+ b2 = nx.to_numpy(b)
+
+ # init uniform
+ if G0 is None:
+ T0 = a2[:, None] * b2[None, :]
+ else:
+ T0 = nx.to_numpy(G0)
+
+ # weak OT loss
+ def f(T):
+ return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1))
+
+ # weak OT gradient
+ def df(T):
+ return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T)
+
+ # solve with conditional gradient and return solution
+ if log:
+ res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs)
+ log['u'] = nx.from_numpy(log['u'], type_as=Xa)
+ log['v'] = nx.from_numpy(log['v'], type_as=Xb)
+ return nx.from_numpy(res, type_as=Xa), log
+ else:
+ return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa)
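A minimal sketch (not part of the patch) of the new solver on two random point clouds; the sizes and the shift of the target cloud are illustrative. With default arguments the weights are uniform and the returned plan satisfies the balanced marginal constraints.

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    Xa = rng.randn(30, 2)
    Xb = rng.randn(40, 2) + 2.0

    G = ot.weak_optimal_transport(Xa, Xb)       # uniform weights by default
    print(np.allclose(G.sum(1), ot.unif(30)),
          np.allclose(G.sum(0), ot.unif(40)))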