author | Gard Spreemann <gspr@nonempty.org> | 2022-04-27 11:49:23 +0200
---|---|---
committer | Gard Spreemann <gspr@nonempty.org> | 2022-04-27 11:49:23 +0200
commit | 35bd2c98b642df78638d7d733bc1a89d873db1de (patch) |
tree | 6bc637624004713808d3097b95acdccbb9608e52 /ot |
parent | c4753bd3f74139af8380127b66b484bc09b50661 (diff) |
parent | eccb1386eea52b94b82456d126bd20cbe3198e05 (diff) |
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'ot')
-rw-r--r-- | ot/__init__.py | 14
-rw-r--r-- | ot/backend.py | 306
-rw-r--r-- | ot/bregman.py | 17
-rw-r--r-- | ot/da.py | 382
-rw-r--r-- | ot/dr.py | 44
-rw-r--r-- | ot/factored.py | 145
-rw-r--r-- | ot/gpu/__init__.py | 50
-rw-r--r-- | ot/gpu/bregman.py | 196
-rw-r--r-- | ot/gpu/da.py | 144
-rw-r--r-- | ot/gpu/utils.py | 101
-rw-r--r-- | ot/gromov.py | 1109
-rw-r--r-- | ot/lp/__init__.py | 123
-rw-r--r-- | ot/lp/cvx.py | 1
-rw-r--r-- | ot/optim.py | 11
-rwxr-xr-x | ot/partial.py | 84
-rw-r--r-- | ot/plot.py | 7
-rw-r--r-- | ot/regpath.py | 545
-rw-r--r-- | ot/stochastic.py | 242
-rw-r--r-- | ot/unbalanced.py | 525
-rw-r--r-- | ot/utils.py | 36
-rw-r--r-- | ot/weak.py | 124
21 files changed, 3053 insertions, 1153 deletions
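
Before the per-file hunks, a quick orientation: the user-visible additions in this 0.8.2 merge are the two new solver modules `ot.weak` and `ot.factored` (re-exported at the top level in the `ot/__init__.py` hunk below) and a substantially larger backend API in `ot/backend.py`. A minimal sketch of the new solvers; the sample data and the rank `r=10` are illustrative, and the signatures are read off the hunks below rather than tested here.

```python
# Illustrative use of the two solvers added in this merge (ot.weak, ot.factored).
# Data and r=10 are made up; signatures follow the hunks below.
import numpy as np
import ot

rng = np.random.RandomState(0)
Xa = rng.randn(50, 2)        # source samples
Xb = rng.randn(40, 2) + 1.0  # target samples

# Weak OT: no ground-cost matrix is passed, only the two sample clouds.
T = ot.weak_optimal_transport(Xa, Xb)            # coupling, shape (50, 40)

# Factored OT: routes mass through an intermediate distribution on r points.
Ga, Gb, X = ot.factored_optimal_transport(Xa, Xb, r=10)
print(T.shape, Ga.shape, Gb.shape, X.shape)      # (50, 40) (50, 10) (10, 40) (10, 2)
```

The `ot/backend.py` hunk is the largest of the patch; a short tour of what it adds, again as a hedged sketch on the NumPy backend:

```python
# Illustrative tour of the backend-API additions shown below: batched
# numpy conversions plus the new linear-algebra helpers.
import numpy as np
from ot.backend import get_backend

a = np.random.rand(3, 3)
C = a @ a.T + np.eye(3)         # symmetric positive definite
nx = get_backend(C)

S = nx.sqrtm(C)                 # new: matrix square root of an SPD matrix
I = nx.solve(C, C)              # new: linear solve, here ~identity
print(nx.trace(I))              # new: trace, ~3.0
rows, cols = nx.where(C > 0.5)  # where() now also accepts a single argument
Cn, Sn = nx.to_numpy(C, S)      # to_numpy/from_numpy now accept batches
```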
diff --git a/ot/__init__.py b/ot/__init__.py index f55819d..86ed94e 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -1,5 +1,4 @@ """ - .. warning:: The list of automatically imported sub-modules is as follows: :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim` @@ -7,13 +6,10 @@ :py:mod:`ot.gromov`, :py:mod:`ot.smooth` :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath` , :py:mod:`ot.unbalanced`. - The following sub-modules are not imported due to additional dependencies: - - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. - :any:`ot.plot` : depends on :code:`matplotlib` - """ # Author: Remi Flamary <remi.flamary@unice.fr> @@ -36,6 +32,8 @@ from . import unbalanced from . import partial from . import backend from . import regpath +from . import weak +from . import factored # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -46,11 +44,14 @@ from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) +from .weak import weak_optimal_transport +from .factored import factored_optimal_transport + # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.8.1.0" +__version__ = "0.8.2" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', @@ -59,5 +60,6 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', - 'max_sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'weak_optimal_transport', + 'factored_optimal_transport', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/backend.py b/ot/backend.py index 58b652b..361ffba 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -87,7 +87,9 @@ Performance # License: MIT License import numpy as np -import scipy.special as scipy +import scipy +import scipy.linalg +import scipy.special as special from scipy.sparse import issparse, coo_matrix, csr_matrix import warnings import time @@ -102,7 +104,7 @@ except ImportError: try: import jax import jax.numpy as jnp - import jax.scipy.special as jscipy + import jax.scipy.special as jspecial from jax.lib import xla_bridge jax_type = jax.numpy.ndarray except ImportError: @@ -202,13 +204,29 @@ class Backend(): def __str__(self): return self.__name__ - # convert to numpy - def to_numpy(self, a): + # convert batch of tensors to numpy + def to_numpy(self, *arrays): + """Returns the numpy version of tensors""" + if len(arrays) == 1: + return self._to_numpy(arrays[0]) + else: + return [self._to_numpy(array) for array in arrays] + + # convert a tensor to numpy + def _to_numpy(self, a): """Returns the numpy version of a tensor""" raise NotImplementedError() - # convert from numpy - def from_numpy(self, a, type_as=None): + # convert batch of arrays from numpy + def from_numpy(self, *arrays, type_as=None): + """Creates tensors cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" + if len(arrays) == 1: + return self._from_numpy(arrays[0], type_as=type_as) + else: + return 
[self._from_numpy(array, type_as=type_as) for array in arrays] + + # convert an array from numpy + def _from_numpy(self, a, type_as=None): """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" raise NotImplementedError() @@ -536,6 +554,16 @@ class Backend(): """ raise NotImplementedError() + def argmin(self, a, axis=None): + r""" + Returns the indices of the minimum values of a tensor along given dimensions. + + This function follows the api from :any:`numpy.argmin` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html + """ + raise NotImplementedError() + def mean(self, a, axis=None): r""" Computes the arithmetic mean of a tensor along given dimensions. @@ -786,6 +814,72 @@ class Backend(): """ raise NotImplementedError() + def solve(self, a, b): + r""" + Solves a linear matrix equation, or system of linear scalar equations. + + This function follows the api from :any:`numpy.linalg.solve`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html + """ + raise NotImplementedError() + + def trace(self, a): + r""" + Returns the sum along diagonals of the array. + + This function follows the api from :any:`numpy.trace`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html + """ + raise NotImplementedError() + + def inv(self, a): + r""" + Computes the inverse of a matrix. + + This function follows the api from :any:`scipy.linalg.inv`. + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html + """ + raise NotImplementedError() + + def sqrtm(self, a): + r""" + Computes the matrix square root. Requires input to be definite positive. + + This function follows the api from :any:`scipy.linalg.sqrtm`. + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html + """ + raise NotImplementedError() + + def isfinite(self, a): + r""" + Tests element-wise for finiteness (not infinity and not Not a Number). + + This function follows the api from :any:`numpy.isfinite`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html + """ + raise NotImplementedError() + + def array_equal(self, a, b): + r""" + True if two arrays have the same shape and elements, False otherwise. + + This function follows the api from :any:`numpy.array_equal`. 
+ + See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html + """ + raise NotImplementedError() + + def is_floating_point(self, a): + r""" + Returns whether or not the input consists of floats + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -802,10 +896,10 @@ class NumpyBackend(Backend): rng_ = np.random.RandomState() - def to_numpy(self, a): + def _to_numpy(self, a): return a - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): if type_as is None: return a elif isinstance(a, float): @@ -936,6 +1030,9 @@ class NumpyBackend(Backend): def argmax(self, a, axis=None): return np.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return np.argmin(a, axis=axis) + def mean(self, a, axis=None): return np.mean(a, axis=axis) @@ -955,7 +1052,7 @@ class NumpyBackend(Backend): return np.unique(a) def logsumexp(self, a, axis=None): - return scipy.logsumexp(a, axis=axis) + return special.logsumexp(a, axis=axis) def stack(self, arrays, axis=0): return np.stack(arrays, axis) @@ -1004,8 +1101,11 @@ class NumpyBackend(Backend): else: return a - def where(self, condition, x, y): - return np.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return np.where(condition) + else: + return np.where(condition, x, y) def copy(self, a): return a.copy() @@ -1046,6 +1146,27 @@ class NumpyBackend(Backend): results[key] = (t1 - t0) / n_runs return results + def solve(self, a, b): + return np.linalg.solve(a, b) + + def trace(self, a): + return np.trace(a) + + def inv(self, a): + return scipy.linalg.inv(a) + + def sqrtm(self, a): + return scipy.linalg.sqrtm(a) + + def isfinite(self, a): + return np.isfinite(a) + + def array_equal(self, a, b): + return np.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.kind == "f" + class JaxBackend(Backend): """ @@ -1075,13 +1196,15 @@ class JaxBackend(Backend): jax.device_put(jnp.array(1, dtype=jnp.float64), d) ] - def to_numpy(self, a): + def _to_numpy(self, a): return np.array(a) def _change_device(self, a, type_as): return jax.device_put(a, type_as.device_buffer.device()) - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return jnp.array(a) else: @@ -1216,6 +1339,9 @@ class JaxBackend(Backend): def argmax(self, a, axis=None): return jnp.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return jnp.argmin(a, axis=axis) + def mean(self, a, axis=None): return jnp.mean(a, axis=axis) @@ -1235,7 +1361,7 @@ class JaxBackend(Backend): return jnp.unique(a) def logsumexp(self, a, axis=None): - return jscipy.logsumexp(a, axis=axis) + return jspecial.logsumexp(a, axis=axis) def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) @@ -1293,8 +1419,11 @@ class JaxBackend(Backend): # Currently, JAX does not support sparse matrices return a - def where(self, condition, x, y): - return jnp.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return jnp.where(condition) + else: + return jnp.where(condition, x, y) def copy(self, a): # No need to copy, JAX arrays are immutable @@ -1339,6 +1468,28 @@ class JaxBackend(Backend): results[key] = (t1 - t0) / n_runs return results + def solve(self, a, b): + return jnp.linalg.solve(a, b) + + def trace(self, a): + return jnp.trace(a) + + def inv(self, a): + return jnp.linalg.inv(a) + + def sqrtm(self, a): + L, V = jnp.linalg.eigh(a) + return (V * 
jnp.sqrt(L)[None, :]) @ V.T + + def isfinite(self, a): + return jnp.isfinite(a) + + def array_equal(self, a, b): + return jnp.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.kind == "f" + class TorchBackend(Backend): """ @@ -1384,10 +1535,10 @@ class TorchBackend(Backend): self.ValFunction = ValFunction - def to_numpy(self, a): + def _to_numpy(self, a): return a.cpu().detach().numpy() - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): if isinstance(a, float): a = np.array(a) if type_as is None: @@ -1397,7 +1548,7 @@ class TorchBackend(Backend): def set_gradients(self, val, inputs, grads): - Func = self.ValFunction() + Func = self.ValFunction res = Func.apply(val, grads, *inputs) @@ -1564,6 +1715,9 @@ class TorchBackend(Backend): def argmax(self, a, axis=None): return torch.argmax(a, dim=axis) + def argmin(self, a, axis=None): + return torch.argmin(a, dim=axis) + def mean(self, a, axis=None): if axis is not None: return torch.mean(a, dim=axis) @@ -1580,8 +1734,11 @@ class TorchBackend(Backend): return torch.linspace(start, stop, num, dtype=torch.float64) def meshgrid(self, a, b): - X, Y = torch.meshgrid(a, b) - return X.T, Y.T + try: + return torch.meshgrid(a, b, indexing="xy") + except TypeError: + X, Y = torch.meshgrid(a, b) + return X.T, Y.T def diag(self, a, k=0): return torch.diag(a, diagonal=k) @@ -1659,8 +1816,11 @@ class TorchBackend(Backend): else: return a - def where(self, condition, x, y): - return torch.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return torch.where(condition) + else: + return torch.where(condition, x, y) def copy(self, a): return torch.clone(a) @@ -1718,6 +1878,28 @@ class TorchBackend(Backend): torch.cuda.empty_cache() return results + def solve(self, a, b): + return torch.linalg.solve(a, b) + + def trace(self, a): + return torch.trace(a) + + def inv(self, a): + return torch.linalg.inv(a) + + def sqrtm(self, a): + L, V = torch.linalg.eigh(a) + return (V * torch.sqrt(L)[None, :]) @ V.T + + def isfinite(self, a): + return torch.isfinite(a) + + def array_equal(self, a, b): + return torch.equal(a, b) + + def is_floating_point(self, a): + return a.dtype.is_floating_point + class CupyBackend(Backend): # pragma: no cover """ @@ -1741,10 +1923,12 @@ class CupyBackend(Backend): # pragma: no cover cp.array(1, dtype=cp.float64) ] - def to_numpy(self, a): + def _to_numpy(self, a): return cp.asnumpy(a) - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return cp.asarray(a) else: @@ -1884,6 +2068,9 @@ class CupyBackend(Backend): # pragma: no cover def argmax(self, a, axis=None): return cp.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return cp.argmin(a, axis=axis) + def mean(self, a, axis=None): return cp.mean(a, axis=axis) @@ -1982,8 +2169,11 @@ class CupyBackend(Backend): # pragma: no cover else: return a - def where(self, condition, x, y): - return cp.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return cp.where(condition) + else: + return cp.where(condition, x, y) def copy(self, a): return a.copy() @@ -2035,6 +2225,28 @@ class CupyBackend(Backend): # pragma: no cover pinned_mempool.free_all_blocks() return results + def solve(self, a, b): + return cp.linalg.solve(a, b) + + def trace(self, a): + return cp.trace(a) + + def inv(self, a): + return cp.linalg.inv(a) + + def sqrtm(self, a): + L, V = 
cp.linalg.eigh(a) + return (V * self.sqrt(L)[None, :]) @ V.T + + def isfinite(self, a): + return cp.isfinite(a) + + def array_equal(self, a, b): + return cp.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.kind == "f" + class TensorflowBackend(Backend): @@ -2060,13 +2272,16 @@ class TensorflowBackend(Backend): "To use TensorflowBackend, you need to activate the tensorflow " "numpy API. You can activate it by running: \n" "from tensorflow.python.ops.numpy_ops import np_config\n" - "np_config.enable_numpy_behavior()" + "np_config.enable_numpy_behavior()", + stacklevel=2 ) - def to_numpy(self, a): + def _to_numpy(self, a): return a.numpy() - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if not isinstance(a, self.__type__): if type_as is None: return tf.convert_to_tensor(a) @@ -2208,6 +2423,9 @@ class TensorflowBackend(Backend): def argmax(self, a, axis=None): return tnp.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return tnp.argmin(a, axis=axis) + def mean(self, a, axis=None): return tnp.mean(a, axis=axis) @@ -2309,8 +2527,11 @@ class TensorflowBackend(Backend): else: return a - def where(self, condition, x, y): - return tnp.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return tnp.where(condition) + else: + return tnp.where(condition, x, y) def copy(self, a): return tf.identity(a) @@ -2364,3 +2585,24 @@ class TensorflowBackend(Backend): results[key] = (t1 - t0) / n_runs return results + + def solve(self, a, b): + return tf.linalg.solve(a, b) + + def trace(self, a): + return tf.linalg.trace(a) + + def inv(self, a): + return tf.linalg.inv(a) + + def sqrtm(self, a): + return tf.linalg.sqrtm(a) + + def isfinite(self, a): + return tnp.isfinite(a) + + def array_equal(self, a, b): + return tnp.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.is_floating diff --git a/ot/bregman.py b/ot/bregman.py index fc20175..c06af2f 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -2525,8 +2525,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, # geometric interpolation delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new)) K = projR(K, delta) - K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0) - + K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0 err = nx.norm(nx.sum(K0, axis=1) - old) old = new if log: @@ -2656,16 +2655,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, classes = nx.unique(Ys[d]) # build the corresponding D_1 and D_2 matrices - Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) - Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) + Dtmp1 = np.zeros((nbclasses, nsk)) + Dtmp2 = np.zeros((nbclasses, nsk)) for c in classes: - nbelemperclass = nx.sum(Ys[d] == c) + nbelemperclass = float(nx.sum(Ys[d] == c)) if nbelemperclass != 0: - Dtmp1[int(c), Ys[d] == c] = 1. - Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass) - D1.append(Dtmp1) - D2.append(Dtmp2) + Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1. + Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. 
/ (nbelemperclass) + D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0])) + D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0])) # build the cost matrix and the Gibbs kernel Mtmp = dist(Xs[d], Xt, metric=metric) @@ -12,12 +12,12 @@ Domain adaptation with optimal transport # License: MIT License import numpy as np -import scipy.linalg as linalg +from .backend import get_backend from .bregman import sinkhorn, jcpot_barycenter from .lp import emd from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots -from .utils import check_params, BaseEstimator +from .utils import list_to_array, check_params, BaseEstimator from .unbalanced import sinkhorn_unbalanced from .optim import cg from .optim import gcg @@ -60,13 +60,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - labels_a : np.ndarray (ns,) + labels_a : array-like (ns,) labels of samples in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples weights in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix reg : float Regularization term for entropic regularization >0 @@ -86,7 +86,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -111,26 +111,28 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, ot.optim.cg : General regularized OT """ + a, labels_a, b, M = list_to_array(a, labels_a, b, M) + nx = get_backend(a, labels_a, b, M) + p = 0.5 epsilon = 1e-3 indices_labels = [] - classes = np.unique(labels_a) + classes = nx.unique(labels_a) for c in classes: - idxc, = np.where(labels_a == c) + idxc, = nx.where(labels_a == c) indices_labels.append(idxc) - W = np.zeros(M.shape) - + W = nx.zeros(M.shape, type_as=M) for cpt in range(numItermax): Mreg = M + eta * W transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, stopThr=stopInnerThr) # the transport has been computed. 
Check if classes are really # separated - W = np.ones(M.shape) + W = nx.ones(M.shape, type_as=M) for (i, c) in enumerate(classes): - majs = np.sum(transp[indices_labels[i]], axis=0) + majs = nx.sum(transp[indices_labels[i]], axis=0) majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs @@ -174,13 +176,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - labels_a : np.ndarray (ns,) + labels_a : array-like (ns,) labels of samples in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix reg : float Regularization term for entropic regularization >0 @@ -200,7 +202,7 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -222,22 +224,25 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, ot.optim.gcg : Generalized conditional gradient for OT problems """ - lstlab = np.unique(labels_a) + a, labels_a, b, M = list_to_array(a, labels_a, b, M) + nx = get_backend(a, labels_a, b, M) + + lstlab = nx.unique(labels_a) def f(G): res = 0 for i in range(G.shape[1]): for lab in lstlab: temp = G[labels_a == lab, i] - res += np.linalg.norm(temp) + res += nx.norm(temp) return res def df(G): - W = np.zeros(G.shape) + W = nx.zeros(G.shape, type_as=G) for i in range(G.shape[1]): for lab in lstlab: temp = G[labels_a == lab, i] - n = np.linalg.norm(temp) + n = nx.norm(temp) if n: W[labels_a == lab, i] = temp / n return W @@ -289,9 +294,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain mu : float,optional Weight for the linear OT loss (>0) @@ -315,9 +320,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters - L : (d, d) ndarray + L : (d, d) array-like Linear mapping matrix ((:math:`d+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters @@ -336,13 +341,15 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, ot.optim.cg : General regularized OT """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1] if bias: - xs1 = np.hstack((xs, np.ones((ns, 1)))) - xstxs = xs1.T.dot(xs1) - Id = np.eye(d + 1) + xs1 = nx.concatenate((xs, nx.ones((ns, 1), type_as=xs)), axis=1) + xstxs = nx.dot(xs1.T, xs1) + Id = nx.eye(d + 1, type_as=xs) Id[-1] = 0 I0 = Id[:, :-1] @@ -350,8 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, return x[:-1, :] else: xs1 = xs - xstxs = xs1.T.dot(xs1) - Id = np.eye(d) + xstxs = nx.dot(xs1.T, xs1) + Id = nx.eye(d, type_as=xs) I0 = Id def sel(x): @@ -360,7 +367,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, if log: log = {'err': []} - a, b = unif(ns), unif(nt) + a = unif(ns, type_as=xs) + b = unif(nt, type_as=xt) M = dist(xs, xt) * ns G = emd(a, b, M) @@ -368,23 +376,26 @@ def 
joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, def loss(L, G): """Compute full loss""" - return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ - np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2) + return ( + nx.sum((nx.dot(xs1, L) - ns * nx.dot(G, xt)) ** 2) + + mu * nx.sum(G * M) + + eta * nx.sum(sel(L - I0) ** 2) + ) def solve_L(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0) + xst = ns * nx.dot(G, xt) + return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0) def solve_G(L, G0): """Update G with CG algorithm""" - xsi = xs1.dot(L) + xsi = nx.dot(xs1, L) def f(G): - return np.sum((xsi - ns * G.dot(xt)) ** 2) + return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) def df(G): - return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) @@ -481,9 +492,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain mu : float,optional Weight for the linear OT loss (>0) @@ -513,9 +524,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters - L : (ns, d) ndarray + L : (ns, d) array-like Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters @@ -534,15 +545,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', ot.optim.cg : General regularized OT """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) ns, nt = xs.shape[0], xt.shape[0] K = kernel(xs, xs, method=kerneltype, sigma=sigma) if bias: - K1 = np.hstack((K, np.ones((ns, 1)))) - Id = np.eye(ns + 1) + K1 = nx.concatenate((K, nx.ones((ns, 1), type_as=xs)), axis=1) + Id = nx.eye(ns + 1, type_as=xs) Id[-1] = 0 - Kp = np.eye(ns + 1) + Kp = nx.eye(ns + 1, type_as=xs) Kp[:ns, :ns] = K # ls regu @@ -550,12 +563,12 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', # Kreg=I # RKHS regul - K0 = K1.T.dot(K1) + eta * Kp + K0 = nx.dot(K1.T, K1) + eta * Kp Kreg = Kp else: K1 = K - Id = np.eye(ns) + Id = nx.eye(ns, type_as=xs) # ls regul # K0 = K1.T.dot(K1)+eta*I @@ -568,7 +581,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', if log: log = {'err': []} - a, b = unif(ns), unif(nt) + a = unif(ns, type_as=xs) + b = unif(nt, type_as=xt) M = dist(xs, xt) * ns G = emd(a, b, M) @@ -576,28 +590,31 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def loss(L, G): """Compute full loss""" - return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ - np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) + return ( + nx.sum((nx.dot(K1, L) - ns * nx.dot(G, xt)) ** 2) + + mu * nx.sum(G * M) + + eta * nx.trace(dots(L.T, Kreg, L)) + ) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(K0, xst) + xst = ns * nx.dot(G, xt) + return nx.solve(K0, xst) def solve_L_bias(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(K0, K1.T.dot(xst)) + xst = ns * nx.dot(G, 
xt) + return nx.solve(K0, nx.dot(K1.T, xst)) def solve_G(L, G0): """Update G with CG algorithm""" - xsi = K1.dot(L) + xsi = nx.dot(K1, L) def f(G): - return np.sum((xsi - ns * G.dot(xt)) ** 2) + return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) def df(G): - return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) @@ -681,15 +698,15 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain reg : float,optional regularization added to the diagonals of covariances (>0) - ws : np.ndarray (ns,1), optional + ws : array-like (ns,1), optional weights for the source samples - wt : np.ndarray (ns,1), optional + wt : array-like (ns,1), optional weights for the target samples bias: boolean, optional estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) @@ -699,9 +716,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, Returns ------- - A : (d, d) ndarray + A : (d, d) array-like Linear operator - b : (1, d) ndarray + b : (1, d) array-like bias log : dict log dictionary return only if log==True in parameters @@ -719,36 +736,38 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) d = xs.shape[1] if bias: - mxs = xs.mean(0, keepdims=True) - mxt = xt.mean(0, keepdims=True) + mxs = nx.mean(xs, axis=0)[None, :] + mxt = nx.mean(xt, axis=0)[None, :] xs = xs - mxs xt = xt - mxt else: - mxs = np.zeros((1, d)) - mxt = np.zeros((1, d)) + mxs = nx.zeros((1, d), type_as=xs) + mxt = nx.zeros((1, d), type_as=xs) if ws is None: - ws = np.ones((xs.shape[0], 1)) / xs.shape[0] + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] if wt is None: - wt = np.ones((xt.shape[0], 1)) / xt.shape[0] + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d) - Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d) + Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) + Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) - Cs12 = linalg.sqrtm(Cs) - Cs_12 = linalg.inv(Cs12) + Cs12 = nx.sqrtm(Cs) + Cs_12 = nx.inv(Cs12) - M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) - A = Cs_12.dot(M0.dot(Cs_12)) + A = dots(Cs_12, M0, Cs_12) - b = mxt - mxs.dot(A) + b = mxt - nx.dot(mxs, A) if log: log = {} @@ -798,15 +817,15 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples weights in the target domain - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix sim : string, optional Type of similarity ('knn' or 'gauss') used to construct the Laplacian. 
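
An aside on the `OT_mapping_linear` hunk just above: replacing `scipy.linalg.sqrtm`/`inv` with the new `nx.sqrtm`/`nx.inv` calls makes the closed-form affine Monge map backend-agnostic. A minimal NumPy sketch with made-up data; only the function name and its return convention come from the patch.

```python
# Minimal sketch: closed-form affine map between two Gaussian-like clouds.
# Illustrative data; OT_mapping_linear is the function refactored above.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(30, 2)
xt = xs @ np.diag([2.0, 0.5]) + np.array([1.0, -1.0])  # affine image of xs

A, b = ot.da.OT_mapping_linear(xs, xt)  # (d, d) operator and (1, d) bias
xst = xs @ A + b                        # transported source samples
print(A.shape, b.shape, xst.shape)      # (2, 2) (1, 2) (30, 2)
```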
@@ -834,7 +853,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -862,9 +881,12 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al raise ValueError( 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__)) + a, b, xs, xt, M = list_to_array(a, b, xs, xt, M) + nx = get_backend(a, b, xs, xt, M) + if sim == 'gauss': if sim_param is None: - sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) + sim_param = 1 / (2 * (nx.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) sS = kernel(xs, xs, method=sim, sigma=sim_param) sT = kernel(xt, xt, method=sim, sigma=sim_param) @@ -874,9 +896,13 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al from sklearn.neighbors import kneighbors_graph - sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray() + sS = nx.from_numpy(kneighbors_graph( + X=nx.to_numpy(xs), n_neighbors=int(sim_param) + ).toarray(), type_as=xs) sS = (sS + sS.T) / 2 - sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray() + sT = nx.from_numpy(kneighbors_graph( + X=nx.to_numpy(xt), n_neighbors=int(sim_param) + ).toarray(), type_as=xt) sT = (sT + sT.T) / 2 else: raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim)) @@ -885,12 +911,14 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al lT = laplacian(sT) def f(G): - return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \ - + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs))))) + return ( + alpha * nx.trace(dots(xt.T, G.T, lS, G, xt)) + + (1 - alpha) * nx.trace(dots(xs.T, G, lT, G.T, xs)) + ) ls2 = lS + lS.T lt2 = lT + lT.T - xt2 = np.dot(xt, xt.T) + xt2 = nx.dot(xt, xt.T) if reg == 'disp': Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T) @@ -898,8 +926,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al M = M + Cs + Ct def df(G): - return alpha * np.dot(ls2, np.dot(G, xt2))\ - + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2))) + return ( + alpha * dots(ls2, G, xt2) + + (1 - alpha) * dots(xs, xs.T, G, lt2) + ) return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax, stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log) @@ -919,7 +949,7 @@ def distribution_estimation_uniform(X): The uniform distribution estimated from :math:`\mathbf{X}` """ - return unif(X.shape[0]) + return unif(X.shape[0], type_as=X) class BaseTransport(BaseEstimator): @@ -973,6 +1003,7 @@ class BaseTransport(BaseEstimator): self : object Returns self. 
""" + nx = self._get_backend(Xs, ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): @@ -984,14 +1015,14 @@ class BaseTransport(BaseEstimator): if (ys is not None) and (yt is not None): if self.limit_max != np.infty: - self.limit_max = self.limit_max * np.max(self.cost_) + self.limit_max = self.limit_max * nx.max(self.cost_) # assumes labeled source samples occupy the first rows # and labeled target samples occupy the first columns - classes = [c for c in np.unique(ys) if c != -1] + classes = [c for c in nx.unique(ys) if c != -1] for c in classes: - idx_s = np.where((ys != c) & (ys != -1)) - idx_t = np.where(yt == c) + idx_s = nx.where((ys != c) & (ys != -1)) + idx_t = nx.where(yt == c) # all the coefficients corresponding to a source sample # and a target sample : @@ -1062,23 +1093,24 @@ class BaseTransport(BaseEstimator): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - if np.array_equal(self.xs_, Xs): + if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, self.xt_) + transp_Xs = nx.dot(transp, self.xt_) else: # perform out of sample mapping - indices = np.arange(Xs.shape[0]) + indices = nx.arange(Xs.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -1087,20 +1119,20 @@ class BaseTransport(BaseEstimator): for bi in batch_ind: # get the nearest neighbor in the source domain D0 = dist(Xs[bi], self.xs_) - idx = np.argmin(D0, axis=1) + idx = nx.argmin(D0, axis=1) # transport the source samples - transp = self.coupling_ / np.sum( - self.coupling_, 1)[:, None] - transp[~ np.isfinite(transp)] = 0 - transp_Xs_ = np.dot(transp, self.xt_) + transp = self.coupling_ / nx.sum( + self.coupling_, axis=1)[:, None] + transp[~ nx.isfinite(transp)] = 0 + transp_Xs_ = nx.dot(transp, self.xt_) # define the transported points transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :] transp_Xs.append(transp_Xs_) - transp_Xs = np.concatenate(transp_Xs, axis=0) + transp_Xs = nx.concatenate(transp_Xs, axis=0) return transp_Xs @@ -1127,26 +1159,27 @@ class BaseTransport(BaseEstimator): International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(ys=ys): - ysTemp = label_normalization(np.copy(ys)) - classes = np.unique(ysTemp) + ysTemp = label_normalization(nx.copy(ys)) + classes = nx.unique(ysTemp) n = len(classes) - D1 = np.zeros((n, len(ysTemp))) + D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True) + transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 for c in classes: D1[int(c), ysTemp == c] = 1 # compute propagated labels - transp_ys = np.dot(D1, transp) + transp_ys = nx.dot(D1, transp) return transp_ys.T @@ -1176,23 +1209,24 @@ class BaseTransport(BaseEstimator): transp_Xt : array-like, shape (n_source_samples, n_features) The transported target samples. 
""" + nx = self.nx # check the necessary inputs parameters are here if check_params(Xt=Xt): - if np.array_equal(self.xt_, Xt): + if nx.array_equal(self.xt_, Xt): # perform standard barycentric mapping - transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None] + transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None] # set nans to 0 - transp_[~ np.isfinite(transp_)] = 0 + transp_[~ nx.isfinite(transp_)] = 0 # compute transported samples - transp_Xt = np.dot(transp_, self.xs_) + transp_Xt = nx.dot(transp_, self.xs_) else: # perform out of sample mapping - indices = np.arange(Xt.shape[0]) + indices = nx.arange(Xt.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -1200,20 +1234,20 @@ class BaseTransport(BaseEstimator): transp_Xt = [] for bi in batch_ind: D0 = dist(Xt[bi], self.xt_) - idx = np.argmin(D0, axis=1) + idx = nx.argmin(D0, axis=1) # transport the target samples - transp_ = self.coupling_.T / np.sum( + transp_ = self.coupling_.T / nx.sum( self.coupling_, 0)[:, None] - transp_[~ np.isfinite(transp_)] = 0 - transp_Xt_ = np.dot(transp_, self.xs_) + transp_[~ nx.isfinite(transp_)] = 0 + transp_Xt_ = nx.dot(transp_, self.xs_) # define the transported points transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :] transp_Xt.append(transp_Xt_) - transp_Xt = np.concatenate(transp_Xt, axis=0) + transp_Xt = nx.concatenate(transp_Xt, axis=0) return transp_Xt @@ -1230,26 +1264,27 @@ class BaseTransport(BaseEstimator): transp_ys : array-like, shape (n_source_samples, nb_classes) Estimated soft source labels. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(yt=yt): - ytTemp = label_normalization(np.copy(yt)) - classes = np.unique(ytTemp) + ytTemp = label_normalization(nx.copy(yt)) + classes = nx.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(ytTemp))) + D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 for c in classes: D1[int(c), ytTemp == c] = 1 # compute propagated samples - transp_ys = np.dot(D1, transp.T) + transp_ys = nx.dot(D1, transp.T) return transp_ys.T @@ -1330,14 +1365,15 @@ class LinearTransport(BaseTransport): self : object Returns self. """ + nx = self._get_backend(Xs, ys, Xt, yt) self.mu_s = self.distribution_estimation(Xs) self.mu_t = self.distribution_estimation(Xt) # coupling estimation returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, - ws=self.mu_s.reshape((-1, 1)), - wt=self.mu_t.reshape((-1, 1)), + ws=nx.reshape(self.mu_s, (-1, 1)), + wt=nx.reshape(self.mu_t, (-1, 1)), bias=self.bias, log=self.log) # deal with the value of log @@ -1348,8 +1384,8 @@ class LinearTransport(BaseTransport): self.log_ = dict() # re compute inverse mapping - self.A1_ = linalg.inv(self.A_) - self.B1_ = -self.B_.dot(self.A1_) + self.A1_ = nx.inv(self.A_) + self.B1_ = -nx.dot(self.B_, self.A1_) return self @@ -1378,10 +1414,11 @@ class LinearTransport(BaseTransport): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. 
""" + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - transp_Xs = Xs.dot(self.A_) + self.B_ + transp_Xs = nx.dot(Xs, self.A_) + self.B_ return transp_Xs @@ -1411,10 +1448,11 @@ class LinearTransport(BaseTransport): transp_Xt : array-like, shape (n_source_samples, n_features) The transported target samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xt=Xt): - transp_Xt = Xt.dot(self.A1_) + self.B1_ + transp_Xt = nx.dot(Xt, self.A1_) + self.B1_ return transp_Xt @@ -2112,6 +2150,7 @@ class MappingTransport(BaseEstimator): self : object Returns self """ + self._get_backend(Xs, ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): @@ -2158,19 +2197,20 @@ class MappingTransport(BaseEstimator): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - if np.array_equal(self.xs_, Xs): + if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, self.xt_) + transp_Xs = nx.dot(transp, self.xt_) else: if self.kernel == "gaussian": K = kernel(Xs, self.xs_, method=self.kernel, @@ -2178,8 +2218,10 @@ class MappingTransport(BaseEstimator): elif self.kernel == "linear": K = Xs if self.bias: - K = np.hstack((K, np.ones((Xs.shape[0], 1)))) - transp_Xs = K.dot(self.mapping_) + K = nx.concatenate( + [K, nx.ones((Xs.shape[0], 1), type_as=K)], axis=1 + ) + transp_Xs = nx.dot(K, self.mapping_) return transp_Xs @@ -2396,6 +2438,7 @@ class JCPOTTransport(BaseTransport): self : object Returns self. 
""" + self._get_backend(*Xs, *ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): @@ -2438,28 +2481,29 @@ class JCPOTTransport(BaseTransport): batch_size : int, optional (default=128) The batch size for out of sample inverse transform """ + nx = self.nx transp_Xs = [] # check the necessary inputs parameters are here if check_params(Xs=Xs): - if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]): + if all([nx.allclose(x, y) for x, y in zip(self.xs_, Xs)]): # perform standard barycentric mapping for each source domain for coupling in self.coupling_: - transp = coupling / np.sum(coupling, 1)[:, None] + transp = coupling / nx.sum(coupling, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs.append(np.dot(transp, self.xt_)) + transp_Xs.append(nx.dot(transp, self.xt_)) else: # perform out of sample mapping - indices = np.arange(Xs.shape[0]) + indices = nx.arange(Xs.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -2470,23 +2514,22 @@ class JCPOTTransport(BaseTransport): transp_Xs_ = [] # get the nearest neighbor in the sources domains - xs = np.concatenate(self.xs_, axis=0) - idx = np.argmin(dist(Xs[bi], xs), axis=1) + xs = nx.concatenate(self.xs_, axis=0) + idx = nx.argmin(dist(Xs[bi], xs), axis=1) # transport the source samples for coupling in self.coupling_: - transp = coupling / np.sum( - coupling, 1)[:, None] - transp[~ np.isfinite(transp)] = 0 - transp_Xs_.append(np.dot(transp, self.xt_)) + transp = coupling / nx.sum(coupling, 1)[:, None] + transp[~ nx.isfinite(transp)] = 0 + transp_Xs_.append(nx.dot(transp, self.xt_)) - transp_Xs_ = np.concatenate(transp_Xs_, axis=0) + transp_Xs_ = nx.concatenate(transp_Xs_, axis=0) # define the transported points transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :] transp_Xs.append(transp_Xs_) - transp_Xs = np.concatenate(transp_Xs, axis=0) + transp_Xs = nx.concatenate(transp_Xs, axis=0) return transp_Xs @@ -2512,32 +2555,36 @@ class JCPOTTransport(BaseTransport): "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
""" + nx = self.nx # check the necessary inputs parameters are here if check_params(ys=ys): - yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0])) + yt = nx.zeros( + (len(nx.unique(nx.concatenate(ys))), self.xt_.shape[0]), + type_as=ys[0] + ) for i in range(len(ys)): - ysTemp = label_normalization(np.copy(ys[i])) - classes = np.unique(ysTemp) + ysTemp = label_normalization(nx.copy(ys[i])) + classes = nx.unique(ysTemp) n = len(classes) ns = len(ysTemp) # perform label propagation - transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 if self.log: D1 = self.log_['D1'][i] else: - D1 = np.zeros((n, ns)) + D1 = nx.zeros((n, ns), type_as=transp) for c in classes: D1[int(c), ysTemp == c] = 1 # compute propagated labels - yt = yt + np.dot(D1, transp) / len(ys) + yt = yt + nx.dot(D1, transp) / len(ys) return yt.T @@ -2555,14 +2602,15 @@ class JCPOTTransport(BaseTransport): transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes) A list of estimated soft source labels """ + nx = self.nx # check the necessary inputs parameters are here if check_params(yt=yt): transp_ys = [] - ytTemp = label_normalization(np.copy(yt)) - classes = np.unique(ytTemp) + ytTemp = label_normalization(nx.copy(yt)) + classes = nx.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(ytTemp))) + D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0]) for c in classes: D1[int(c), ytTemp == c] = 1 @@ -2570,12 +2618,12 @@ class JCPOTTransport(BaseTransport): for i in range(len(self.xs_)): # perform label propagation - transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute propagated labels - transp_ys.append(np.dot(D1, transp.T).T) + transp_ys.append(nx.dot(D1, transp.T).T) return transp_ys @@ -11,6 +11,7 @@ Dimension reduction with OT # Author: Remi Flamary <remi.flamary@unice.fr> # Minhui Huang <mhhuang@ucdavis.edu> +# Jakub Zadrozny <jakub.r.zadrozny@gmail.com> # # License: MIT License @@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k): return G +def logsumexp(M, axis): + r"""Log-sum-exp reduction compatible with autograd (no numpy implementation) + """ + amax = np.amax(M, axis=axis, keepdims=True) + return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis) + + +def sinkhorn_log(w1, w2, M, reg, k): + r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd) + """ + Mr = -M / reg + ui = np.zeros((M.shape[0],)) + vi = np.zeros((M.shape[1],)) + log_w1 = np.log(w1) + log_w2 = np.log(w2) + for i in range(k): + vi = log_w2 - logsumexp(Mr + ui[:, None], 0) + ui = log_w1 - logsumexp(Mr + vi[None, :], 1) + G = np.exp(ui[:, None] + Mr + vi[None, :]) + return G + + def split_classes(X, y): r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}` """ @@ -110,7 +133,7 @@ def fda(X, y, p=2, reg=1e-16): return Popt, proj -def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False): +def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False): r""" Wasserstein Discriminant Analysis :ref:`[11] <references-wda>` @@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, 
verbose=0, P0=None, no - :math:`W` is entropic regularized Wasserstein distances - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sparse cost matrices, you should use the + :py:func:`ot.dr.sinkhorn_log` solver that will avoid numerical + errors, but can be slow in practice. + Parameters ---------- X : ndarray, shape (n, d) @@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no solver : None | str, optional None for steepest descent or 'TrustRegions' for trust regions algorithm else should be a pymanopt.solvers + sinkhorn_method : str + method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log' P0 : ndarray, shape (d, p) Initial starting point for projection. normalize : bool, optional @@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063. """ # noqa + if sinkhorn_method.lower() == 'sinkhorn': + sinkhorn_solver = sinkhorn + elif sinkhorn_method.lower() == 'sinkhorn_log': + sinkhorn_solver = sinkhorn_log + else: + raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method) + mx = np.mean(X) X -= mx.reshape((1, -1)) @@ -193,7 +233,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P) M = dist(xi, xj) - G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k) + G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k) if j == 0: loss_w += np.sum(G * M) else: diff --git a/ot/factored.py b/ot/factored.py new file mode 100644 index 0000000..abc2445 --- /dev/null +++ b/ot/factored.py @@ -0,0 +1,145 @@ +""" +Factored OT solvers (low rank, cost or OT plan) +""" + +# Author: Remi Flamary <remi.flamary@polytehnique.edu> +# +# License: MIT License + +from .backend import get_backend +from .utils import dist +from .lp import emd +from .bregman import sinkhorn + +__all__ = ['factored_optimal_transport'] + + +def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs): + r"""Solves factored OT problem and return OT plans and intermediate distribution + + This function solve the following OT problem [40]_ + + .. math:: + \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) + + where : + + - :math:`\mu_a` and :math:`\mu_b` are empirical distributions. + - :math:`\mu` is an empirical distribution with r samples + + And returns the two OT plans between + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Uses the conditional gradient algorithm to solve the problem proposed in + :ref:`[39] <references-weak>`. 
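
Backing up to the `ot/dr.py` hunk above: the new `sinkhorn_method` switch lets `wda` call the added log-domain solver when the entropic regularization is small. A hedged usage sketch; the toy data and `reg=0.01` are illustrative, and note that `ot.dr` additionally requires the optional `pymanopt` and `autograd` dependencies.

```python
# Illustrative: selecting the new log-domain Sinkhorn inside ot.dr.wda.
# Toy data; requires the optional pymanopt and autograd dependencies.
import numpy as np
import ot.dr

rng = np.random.RandomState(42)
X = np.vstack([rng.randn(20, 5), rng.randn(20, 5) + 2.0])
y = np.repeat([0, 1], 20)

# 'sinkhorn_log' runs the fixed-iteration solver in the log domain, which
# avoids under/overflow when reg is small, at some extra cost per iteration.
P, proj = ot.dr.wda(X, y, p=2, reg=0.01, k=10, sinkhorn_method='sinkhorn_log')
print(proj(X).shape)  # (40, 2): samples projected on the discriminant subspace
```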
+ + Parameters + ---------- + Xa : (ns,d) array-like, float + Source samples + Xb : (nt,d) array-like, float + Target samples + a : (ns,) array-like, float + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list)) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on the relative variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + Ga: array-like, shape (ns, r) + Optimal transportation matrix between source and the intermediate + distribution + Gb: array-like, shape (r, nt) + Optimal transportation matrix between the intermediate and target + distribution + X: array-like, shape (r, d) + Support of the intermediate distribution + log: dict, optional + If input log is true, a dictionary containing the cost and dual + variables and exit status + + + .. _references-factored: + References + ---------- + .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, + G., & Weed, J. (2019, April). Statistical optimal transport via factored + couplings. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2454-2465). PMLR. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General + regularized OT + """ + + nx = get_backend(Xa, Xb) + + n_a = Xa.shape[0] + n_b = Xb.shape[0] + d = Xa.shape[1] + + if a is None: + a = nx.ones((n_a), type_as=Xa) / n_a + if b is None: + b = nx.ones((n_b), type_as=Xb) / n_b + + if X0 is None: + X = nx.randn(r, d, type_as=Xa) + else: + X = X0 + + w = nx.ones(r, type_as=Xa) / r + + def solve_ot(X1, X2, w1, w2): + M = dist(X1, X2) + if reg > 0: + G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs) + log['cost'] = nx.sum(G * M) + return G, log + else: + return emd(w1, w2, M, log=True, **kwargs) + + norm_delta = [] + + # solve the barycenter + for i in range(numItermax): + + old_X = X + + # solve OT with template + Ga, loga = solve_ot(Xa, X, a, w) + Gb, logb = solve_ot(X, Xb, w, b) + + X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r + + delta = nx.norm(X - old_X) + if delta < stopThr: + break + if log: + norm_delta.append(delta) + + if log: + log_dic = {'delta_iter': norm_delta, + 'ua': loga['u'], + 'va': loga['v'], + 'ub': logb['u'], + 'vb': logb['v'], + 'costa': loga['cost'], + 'costb': logb['cost'], + } + return Ga, Gb, X, log_dic + + return Ga, Gb, X diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py deleted file mode 100644 index 12db605..0000000 --- a/ot/gpu/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- -""" -GPU implementation for several OT solvers and utility -functions. - -The GPU backend in handled by `cupy -<https://cupy.chainer.org/>`_. - -.. warning:: - This module is now deprecated and will be removed in future releases. POT - now privides a backend mechanism that allows for solving prolem on GPU wth - the pytorch backend. - - -.. warning:: - Note that by default the module is not imported in :mod:`ot`. In order to - use it you need to explicitely import :mod:`ot.gpu` . - -By default, the functions in this module accept and return numpy arrays -in order to proide drop-in replacement for the other POT function but -the transfer between CPU en GPU comes with a significant overhead. 
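
The deprecation note in the module being deleted here points at its replacement: instead of `ot.gpu`, pass GPU tensors from a supported backend and the regular top-level solvers stay on the device. A sketch assuming a CUDA-enabled PyTorch install; nothing in it comes from the patch itself.

```python
# Sketch of the replacement for the deleted ot.gpu module: feed torch CUDA
# tensors to the regular solvers and computation stays on the GPU.
# Assumes a CUDA-enabled PyTorch install; shapes and reg are illustrative.
import torch
import ot

xs = torch.randn(100, 3, device="cuda")
xt = torch.randn(120, 3, device="cuda")
a = torch.full((100,), 1 / 100, device="cuda")
b = torch.full((120,), 1 / 120, device="cuda")

M = ot.dist(xs, xt)                # cost matrix, computed on the GPU
G = ot.sinkhorn(a, b, M, reg=0.1)  # CUDA tensor out, no host round-trip
```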
- -In order to get the best performances, we recommend to give only cupy -arrays to the functions and desactivate the conversion to numpy of the -result of the function with parameter ``to_numpy=False``. - -""" - -# Author: Remi Flamary <remi.flamary@unice.fr> -# Leo Gautheron <https://github.com/aje> -# -# License: MIT License - -import warnings - -from . import bregman -from . import da -from .bregman import sinkhorn -from .da import sinkhorn_lpl1_mm - -from . import utils -from .utils import dist, to_gpu, to_np - - -warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning) - - -__all__ = ["utils", "dist", "sinkhorn", - "sinkhorn_lpl1_mm", 'bregman', 'da', 'to_gpu', 'to_np'] - diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py deleted file mode 100644 index 76af00e..0000000 --- a/ot/gpu/bregman.py +++ /dev/null @@ -1,196 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Bregman projections for regularized OT with GPU -""" - -# Author: Remi Flamary <remi.flamary@unice.fr> -# Leo Gautheron <https://github.com/aje> -# -# License: MIT License - -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations -from . import utils - - -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, - verbose=False, log=False, to_numpy=True, **kwargs): - r""" - Solve the entropic regularization optimal transport on GPU - - If the input matrix are in numpy format, they will be uploaded to the - GPU first which can incur significant time overhead. - - The function solves the following optimization problem: - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : - - - M is the (ns,nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term >0 - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - to_numpy : boolean, optional (default True) - If true convert back the GPU array result to numpy format. - - - Returns - ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [2] M. 
Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - - a = cp.asarray(a) - b = cp.asarray(b) - M = cp.asarray(M) - - if len(a) == 0: - a = np.ones((M.shape[0],)) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],)) / M.shape[1] - - # init data - Nini = len(a) - Nfin = len(b) - - if len(b.shape) > 1: - nbb = b.shape[1] - else: - nbb = 0 - - if log: - log = {'err': []} - - # we assume that no distances are null except those of the diagonal of - # distances - if nbb: - u = np.ones((Nini, nbb)) / Nini - v = np.ones((Nfin, nbb)) / Nfin - else: - u = np.ones(Nini) / Nini - v = np.ones(Nfin) / Nfin - - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) - - Kp = (1 / a).reshape(-1, 1) * K - cpt = 0 - err = 1 - while (err > stopThr and cpt < numItermax): - uprev = u - vprev = v - - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) - - if (np.any(KtransposeU == 0) or - np.any(np.isnan(u)) or np.any(np.isnan(v)) or - np.any(np.isinf(u)) or np.any(np.isinf(v))): - # we have reached the machine precision - # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) - u = uprev - v = vprev - break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - if nbb: - err = np.sqrt( - np.sum((u - uprev)**2) / np.sum((u)**2) - + np.sum((v - vprev)**2) / np.sum((v)**2) - ) - else: - # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - tmp2 = np.sum(u[:, None] * K * v[None, :], 0) - #tmp2=np.einsum('i,ij,j->j', u, K, v) - err = np.linalg.norm(tmp2 - b) # violation of marginal - if log: - log['err'].append(err) - - if verbose: - if cpt % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 - if log: - log['u'] = u - log['v'] = v - - if nbb: # return only loss - #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory) - res = np.empty(nbb) - for i in range(nbb): - res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i]) - if to_numpy: - res = utils.to_np(res) - if log: - return res, log - else: - return res - - else: # return OT matrix - res = u.reshape((-1, 1)) * K * v.reshape((1, -1)) - if to_numpy: - res = utils.to_np(res) - if log: - return res, log - else: - return res - - -# define sinkhorn as sinkhorn_knopp -sinkhorn = sinkhorn_knopp diff --git a/ot/gpu/da.py b/ot/gpu/da.py deleted file mode 100644 index 7adb830..0000000 --- a/ot/gpu/da.py +++ /dev/null @@ -1,144 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Domain adaptation with optimal transport with GPU implementation -""" - -# Author: Remi Flamary <remi.flamary@unice.fr> -# Nicolas Courty <ncourty@irisa.fr> -# Michael Perrot <michael.perrot@univ-st-etienne.fr> -# Leo Gautheron <https://github.com/aje> -# -# License: MIT License - - -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations -import numpy as npp -from . 
import utils - -from .bregman import sinkhorn - - -def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, - numInnerItermax=200, stopInnerThr=1e-9, verbose=False, - log=False, to_numpy=True): - """ - Solve the entropic regularization optimal transport problem with nonconvex - group lasso regularization on GPU - - If the input matrix are in numpy format, they will be uploaded to the - GPU first which can incur significant time overhead. - - - The function solves the following optimization problem: - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma) - + \eta \Omega_g(\gamma) - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : - - - M is the (ns,nt) metric cost matrix - - :math:`\Omega_e` is the entropic regularization term - :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\Omega_g` is the group lasso regulaization term - :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` - where :math:`\mathcal{I}_c` are the index of samples from class c - in the source domain. - - a and b are source and target weights (sum to 1) - - The algorithm used for solving the problem is the generalised conditional - gradient as proposed in [5]_ [7]_ - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - labels_a : np.ndarray (ns,) - labels of samples in the source domain - b : np.ndarray (nt,) - samples weights in the target domain - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term for entropic regularization >0 - eta : float, optional - Regularization term for group lasso regularization >0 - numItermax : int, optional - Max number of iterations - numInnerItermax : int, optional - Max number of iterations (inner sinkhorn solver) - stopInnerThr : float, optional - Stop threshold on error (inner sinkhorn solver) (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - to_numpy : boolean, optional (default True) - If true convert back the GPU array result to numpy format. - - - Returns - ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, - "Optimal Transport for Domain Adaptation," in IEEE - Transactions on Pattern Analysis and Machine Intelligence , - vol.PP, no.99, pp.1-1 - .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). - Generalized conditional gradient: analysis of convergence - and applications. arXiv preprint arXiv:1510.06567. - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - - """ - - a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M) - - p = 0.5 - epsilon = 1e-3 - - indices_labels = [] - labels_a2 = cp.asnumpy(labels_a) - classes = npp.unique(labels_a2) - for c in classes: - idxc = utils.to_gpu(*npp.where(labels_a2 == c)) - indices_labels.append(idxc) - - W = np.zeros(M.shape) - - for cpt in range(numItermax): - Mreg = M + eta * W - transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, - stopThr=stopInnerThr, to_numpy=False) - # the transport has been computed. 
Check if classes are really - # separated - W = np.ones(M.shape) - for (i, c) in enumerate(classes): - - majs = np.sum(transp[indices_labels[i]], axis=0) - majs = p * ((majs + epsilon)**(p - 1)) - W[indices_labels[i]] = majs - - if to_numpy: - return utils.to_np(transp) - else: - return transp diff --git a/ot/gpu/utils.py b/ot/gpu/utils.py deleted file mode 100644 index 41e168a..0000000 --- a/ot/gpu/utils.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Utility functions for GPU -""" - -# Author: Remi Flamary <remi.flamary@unice.fr> -# Nicolas Courty <ncourty@irisa.fr> -# Leo Gautheron <https://github.com/aje> -# -# License: MIT License - -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations - - -def euclidean_distances(a, b, squared=False, to_numpy=True): - """ - Compute the pairwise euclidean distance between matrices a and b. - - If the input matrix are in numpy format, they will be uploaded to the - GPU first which can incur significant time overhead. - - Parameters - ---------- - a : np.ndarray (n, f) - first matrix - b : np.ndarray (m, f) - second matrix - to_numpy : boolean, optional (default True) - If true convert back the GPU array result to numpy format. - squared : boolean, optional (default False) - if True, return squared euclidean distance matrix - - Returns - ------- - c : (n x m) np.ndarray or cupy.ndarray - pairwise euclidean distance distance matrix - """ - - a, b = to_gpu(a, b) - - a2 = np.sum(np.square(a), 1) - b2 = np.sum(np.square(b), 1) - - c = -2 * np.dot(a, b.T) - c += a2[:, None] - c += b2[None, :] - - if not squared: - np.sqrt(c, out=c) - if to_numpy: - return to_np(c) - else: - return c - - -def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True): - """Compute distance between samples in x1 and x2 on gpu - - Parameters - ---------- - - x1 : np.array (n1,d) - matrix with n1 samples of size d - x2 : np.array (n2,d), optional - matrix with n2 samples of size d (if None then x2=x1) - metric : str - Metric from 'sqeuclidean', 'euclidean', - - - Returns - ------- - - M : np.array (n1,n2) - distance matrix computed with given metric - - """ - if x2 is None: - x2 = x1 - if metric == "sqeuclidean": - return euclidean_distances(x1, x2, squared=True, to_numpy=to_numpy) - elif metric == "euclidean": - return euclidean_distances(x1, x2, squared=False, to_numpy=to_numpy) - else: - raise NotImplementedError - - -def to_gpu(*args): - """ Upload numpy arrays to GPU and return them""" - if len(args) > 1: - return (cp.asarray(x) for x in args) - else: - return cp.asarray(args[0]) - - -def to_np(*args): - """ convert GPU arras to numpy and return them""" - if len(args) > 1: - return (cp.asnumpy(x) for x in args) - else: - return cp.asnumpy(args[0]) diff --git a/ot/gromov.py b/ot/gromov.py index 6544260..55ab0bd 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -7,6 +7,7 @@ Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers # Nicolas Courty <ncourty@irisa.fr>
# Rémi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License
@@ -17,7 +18,7 @@ from .bregman import sinkhorn
 from .utils import dist, UndefinedParameter, list_to_array
from .optim import cg
from .lp import emd_1d, emd
-from .utils import check_random_state
+from .utils import check_random_state, unif
from .backend import get_backend
@@ -320,7 +321,7 @@ def update_kl_loss(p, lambdas, T, Cs):
     return nx.exp(tmpsum / ppt)
-def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
r"""
Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -338,6 +339,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
     - :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -361,6 +366,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
     armijo : bool, optional
         If True the step of the line-search is found via an Armijo line search. Else a closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+        Otherwise G0 must satisfy the marginal constraints and is used as the initial transport plan of the solver.
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
@@ -385,18 +393,26 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
     """
p, q = list_to_array(p, q)
-
p0, q0, C10, C20 = p, q, C1, C2
- nx = get_backend(p0, q0, C10, C20)
-
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
- G0 = p[:, None] * q[None, :]
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -414,7 +430,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
         return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
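For illustration, a minimal usage sketch of the new warm-start parameter (hypothetical toy data; any ``G0`` satisfying the marginal constraints works, for instance a plan returned by a previous call):

import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 2), rng.randn(30, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)  # intra-domain cost matrices
p, q = ot.unif(20), ot.unif(30)  # uniform weights

# default initialization is the outer product of the marginals
T = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
# warm-start a second solve from the previous plan
T2 = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=T)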
-def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
r"""
Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -436,6 +452,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
     Note that when using backends, this loss function is differentiable wrt the
     matrices and weights for quadratic loss using the gradients from [38]_.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -459,6 +479,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
     armijo : bool, optional
         If True the step of the line-search is found via an Armijo line search. Else a closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+        Otherwise G0 must satisfy the marginal constraints and is used as the initial transport plan of the solver.
Returns
-------
@@ -483,9 +506,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
     """
p, q = list_to_array(p, q)
-
p0, q0, C10, C20 = p, q, C1, C2
- nx = get_backend(p0, q0, C10, C20)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
@@ -494,7 +520,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
     constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -514,10 +546,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
     gw = log_gw['gw_dist']
if loss_fun == 'square_loss':
- gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
- gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+ gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+ gC1 = nx.from_numpy(gC1, type_as=C10)
+ gC2 = nx.from_numpy(gC2, type_as=C10)
gw = nx.set_gradients(gw, (p0, q0, C10, C20),
- (log_gw['u'], log_gw['v'], gC1, gC2))
+ (log_gw['u'] - nx.mean(log_gw['u']),
+ log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
if log:
return gw, log_gw
@@ -525,7 +560,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
         return gw
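A minimal sketch of the differentiability noted above, assuming the torch backend is installed; the centered dual potentials provide the gradients with respect to the weights:

import torch
import ot

C1 = torch.rand(10, 10)
C1 = 0.5 * (C1 + C1.T)  # symmetrize
C2 = torch.rand(12, 12)
C2 = 0.5 * (C2 + C2.T)
C1.requires_grad_(True)
C2.requires_grad_(True)
p, q = torch.ones(10) / 10, torch.ones(12) / 12

gw = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss')
gw.backward()
# C1.grad and C2.grad now hold the gradients from [38]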
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
r"""
Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
@@ -545,6 +580,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
     - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
@@ -566,6 +605,9 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
     armijo : bool, optional
         If True the step of the line-search is found via an Armijo line search. Else a closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+        Otherwise G0 must satisfy the marginal constraints and is used as the initial transport plan of the solver.
log : bool, optional
record log if True
**kwargs : dict
@@ -588,20 +630,28 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
         (ICML). 2019.
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20, M0 = p, q, C1, C2, M
- nx = get_backend(p0, q0, C10, C20, M0)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
M = nx.to_numpy(M0)
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
-
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -610,19 +660,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
     if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
-
fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
-
log['fgw_dist'] = fgw_dist
log['u'] = nx.from_numpy(log['u'], type_as=C10)
log['v'] = nx.from_numpy(log['v'], type_as=C10)
return nx.from_numpy(res, type_as=C10), log
-
else:
return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
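A usage sketch on hypothetical attributed toy data, where ``M`` is the pairwise feature cost between the two graphs:

import numpy as np
import ot

rng = np.random.RandomState(0)
F1, F2 = rng.randn(15, 3), rng.randn(10, 3)  # node features
C1 = ot.dist(rng.randn(15, 2))  # intra-graph structure costs
C2 = ot.dist(rng.randn(10, 2))
M = ot.dist(F1, F2)  # inter-graph feature cost
p, q = ot.unif(15), ot.unif(10)

T, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, alpha=0.5, log=True)
print(log['fgw_dist'])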
-def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
r"""
    Computes the FGW distance between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
@@ -645,6 +692,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
     The algorithm used for solving the problem is conditional gradient as
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Note that when using backends, this loss function is differentiable wrt the
    matrices and weights for quadratic loss using the gradients from [38]_.
@@ -667,6 +718,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
     armijo : bool, optional
         If True the step of the line-search is found via an Armijo line search.
         Else a closed form is used. If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+        Otherwise G0 must satisfy the marginal constraints and is used as the initial transport plan of the solver.
log : bool, optional
Record log if True.
**kwargs : dict
@@ -695,7 +749,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
     p, q = list_to_array(p, q)
p0, q0, C10, C20, M0 = p, q, C1, C2, M
- nx = get_backend(p0, q0, C10, C20, M0)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
@@ -705,7 +763,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
     constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -725,10 +789,14 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
         log_fgw['T'] = T0
if loss_fun == 'square_loss':
- gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
- gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+ gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+ gC1 = nx.from_numpy(gC1, type_as=C10)
+ gC2 = nx.from_numpy(gC2, type_as=C10)
fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
- (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T0))
if log:
return fgw_dist, log_fgw
@@ -1780,3 +1848,988 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
         for s in range(len(Ts))
])
return tmpsum
+
+
+def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2
+
+ such that, :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+    - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+    - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+    - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - reg is the regularization coefficient.
+
+    The stochastic algorithm used for estimating the graph dictionary atoms is the one proposed in [38]_.
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns, ns).
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+        Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+        Number of epochs used to learn the dictionary. Default is 20.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate: float, optional
+ Learning rate used for the stochastic gradient descent. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary.
+ If set to None (Default), the dictionary will be initialized randomly.
+        Else Cdict must have shape (D, nt, nt), i.e. match the provided D and nt.
+    projection: str, optional
+        If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary.
+        Else the set of atoms is :math:`R^{nt \times nt}`. Default is 'nonnegative_symmetric'.
+    use_log: bool, optional
+        If set to True, the loss evolution by batches and epochs is tracked. Default is True.
+    use_adam_optimizer: bool, optional
+        If set to True, the Adam optimizer with default settings is used as the adaptive learning rate strategy.
+        Else SGD with a fixed learning rate is performed. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+        If use_log is True, contains the loss evolution by batches and epochs.
+
+    References
+    ----------
+    .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ # Handle backend of non-optional arguments
+ Cs0 = Cs
+ nx = get_backend(*Cs0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ dataset_size = len(Cs)
+ # Handle backend of optional arguments
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0
+ if use_adam_optimizer:
+ adam_moments = _initialize_adam_optimizer(Cdict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ Cdict_best_state = Cdict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+ # batch sampling
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+                # BCD solver for Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner,
+ max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ if use_adam_optimizer:
+ Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments)
+ else:
+ Cdict -= learning_rate * grad_Cdict
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ if verbose:
+ print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), log
+
+
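A small end-to-end sketch on hypothetical toy structures (all shapes and values below are illustrative only):

import numpy as np
import ot

rng = np.random.RandomState(42)
Cs = []
for ns in [8, 10, 12, 9]:
    C = rng.rand(ns, ns)
    Cs.append(0.5 * (C + C.T))  # symmetric toy cost matrices

# learn D = 3 atoms of size nt = 6
Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
    Cs, D=3, nt=6, reg=0.01, epochs=5, batch_size=2)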
+def _initialize_adam_optimizer(variable):
+
+ # Initialization for our numpy implementation of adam optimizer
+ atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor
+ atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor
+ atoms_adam_count = 1
+
+ return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count}
+
+
+def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09):
+
+ adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad
+ adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2)
+ unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count'])
+ unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count'])
+ variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps)
+ adam_moments['count'] += 1
+
+ return variable, adam_moments
+
+
+def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
+ r"""
+ Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`.
+
+ .. math::
+ \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - reg is the regularization coefficient.
+
+    The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_, Algorithm 1.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ T: array-like (ns, nt)
+ Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})`
+ current_loss: float
+ reconstruction error
+
+    References
+    ----------
+    .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ C0, Cdict0 = C, Cdict
+ nx = get_backend(C0, Cdict0)
+ C = nx.to_numpy(C0)
+ Cdict = nx.to_numpy(Cdict0)
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+
+ w = unif(D) # Initialize uniformly the unmixing w
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+
+ const_q = q[:, None] * q[None, :]
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+ # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
+ T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs)
+ current_loss = log['gw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing(
+ C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T,
+ starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs
+ )
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
+
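Continuing the sketch after ``gromov_wasserstein_dictionary_learning`` above, a single structure can then be embedded onto the learned atoms:

# unmixing of the first toy structure onto the learned dictionary
w, Cembedded, T, loss = ot.gromov.gromov_wasserstein_linear_unmixing(
    Cs[0], Cdict, reg=0.01)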
+def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs):
+ r"""
+ Returns for a fixed admissible transport plan,
+ the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+        \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l})^2 T_{i,k} T_{j,l} - reg \| \mathbf{w} \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`.
+ - reg is the regularization coefficient.
+
+    The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]_.
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ w: array-like, shape (D,)
+ Linear unmixing of the input structure onto the dictionary
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between the input structure and its representation in the dictionary.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+
+ Returns
+ -------
+    w: ndarray (D,)
+        optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing.
+    Cembedded: ndarray (nt, nt)
+        updated embedded structure :math:`\sum_d w_d \mathbf{C_{dict}[d]}`.
+    current_loss: float
+        value of the reconstruction loss after the update.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+
+ previous_loss = current_loss
+ # 1) Compute gradient at current point w
+ grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ grad_w -= 2 * reg * w
+
+ # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+        if previous_loss != 0:  # note that the loss can be negative if reg > 0
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ count += 1
+
+ return w, Cembedded, current_loss
+
+
+def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing
+ .. math::
+ \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+    grad_w : array-like, shape (D,)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+    reg : float, optional
+        Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`.
+
+    Returns
+    -------
+    gamma: float
+        Optimal value for the line-search step
+    a: float
+        Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma + c` of the reconstruction loss
+    b: float
+        Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma + c` of the reconstruction loss
+    Cembedded_diff: numpy array, shape (nt, nt)
+        Difference between the models evaluated at :math:`\mathbf{x}` and at :math:`\mathbf{w}`.
+ """
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ a = trace_diffx - trace_diffw
+ b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+
+ if a > 0:
+ gamma = min(1, max(0, - b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff
+
+
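For reference, the case analysis above is the exact minimizer over :math:`[0, 1]` of the quadratic :math:`a\gamma^2 + b\gamma + c` assembled by this helper, with :math:`c` the current loss:

.. math::
    \gamma^\star = \begin{cases}
        \min(1, \max(0, -\frac{b}{2a})) & \text{if } a > 0 \\
        1 & \text{if } a \leq 0 \text{ and } a + b < 0 \\
        0 & \text{otherwise}
    \end{cases}

since for :math:`a \leq 0` the quadratic is concave, so its minimum over :math:`[0, 1]` is attained at an endpoint, and the value at :math:`\gamma = 1` is :math:`a + b + c` against :math:`c` at :math:`\gamma = 0`.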
+def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
+ Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2
+
+
+ Such that :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+    - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) feature matrix of variable size ns and fixed dimension d.
+    - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+    - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+    - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+    - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+
+    The stochastic algorithm used for estimating the attributed graph dictionary atoms is the one proposed in [38]_.
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns,ns).
+ Ys : list of S array-like, shape (ns, d)
+        List of feature matrices of variable size (ns,d) with d fixed.
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ alpha : float
+ Trade-off parameter of Fused Gromov-Wasserstein
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+        Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+        Number of epochs used to learn the dictionary. Default is 20.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate_C: float, optional
+ Learning rate used for the stochastic gradient descent on Cdict. Default is 1.
+ learning_rate_Y: float, optional
+ Learning rate used for the stochastic gradient descent on Ydict. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary structures Cdict.
+ If set to None (Default), the dictionary will be initialized randomly.
+        Else Cdict must have shape (D, nt, nt), i.e. match the provided D and nt.
+ Ydict_init: list of D array-like with shape (nt, d), optional
+ Used to initialize the dictionary features Ydict.
+ If set to None, the dictionary features will be initialized randomly.
+        Else Ydict must have shape (D, nt, d), where d is the feature dimension of the inputs Ys, and match the provided nt.
+ projection: str, optional
+ If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
+        Else the set of atoms is :math:`R^{nt \times nt}`. Default is 'nonnegative_symmetric'.
+    use_log: bool, optional
+        If set to True, the loss evolution by batches and epochs is tracked. Default is False.
+    use_adam_optimizer: bool, optional
+        If set to True, the Adam optimizer with default settings is used as the adaptive learning rate strategy.
+        Else SGD with a fixed learning rate is performed. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ Ydict_best_state : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+        If use_log is True, contains the loss evolution by batches and epochs.
+
+    References
+    ----------
+    .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ Cs0, Ys0 = Cs, Ys
+ nx = get_backend(*Cs0, *Ys0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ Ys = [nx.to_numpy(Y) for Y in Ys0]
+
+ d = Ys[0].shape[-1]
+ dataset_size = len(Cs)
+
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+ if Ydict_init is None:
+ # Initialize randomly features of dictionary atoms based on samples distribution by feature component
+ dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
+ Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+ else:
+ Ydict = nx.to_numpy(Ydict_init).copy()
+ assert Ydict.shape == (D, nt, d)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_adam_optimizer:
+ adam_moments_C = _initialize_adam_optimizer(Cdict)
+ adam_moments_Y = _initialize_adam_optimizer(Ydict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+
+ # Batch iterations
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ys_embedded = np.zeros((batch_size, nt, d))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+                # BCD solver for Fused Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q,
+ tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ grad_Ydict = np.zeros_like(Ydict)
+
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx])
+ grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ grad_Ydict *= 2 / batch_size
+
+ if use_adam_optimizer:
+ Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C)
+ Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y)
+ else:
+ Cdict -= learning_rate_C * grad_Cdict
+ Ydict -= learning_rate_Y * grad_Ydict
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ if verbose:
+ print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log
+
+
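A usage sketch mirroring the unattributed case, with hypothetical toy features attached to each structure:

import numpy as np
import ot

rng = np.random.RandomState(42)
Cs, Ys = [], []
for ns in [8, 10, 12, 9]:
    C = rng.rand(ns, ns)
    Cs.append(0.5 * (C + C.T))  # symmetric toy structures
    Ys.append(rng.randn(ns, 3))  # 3-dimensional node features

Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
    Cs, Ys, D=3, nt=6, alpha=0.5, reg=0.01, epochs=5, batch_size=2)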
+def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
+ r"""
+ Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`
+
+ .. math::
+ \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+    such that:
+
+    - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+    - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+    - :math:`\mathbf{Y}` is a (ns,d) feature matrix of variable size ns and fixed dimension d.
+    - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+    - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+    - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C}`
+    - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+    The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_, Algorithm 6.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Ydict : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ Yembedded: array-like, shape (nt,d)
+ embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
+ T: array-like (ns,nt)
+ Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`.
+ current_loss: float
+ reconstruction error
+
+    References
+    ----------
+    .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict
+ nx = get_backend(C0, Y0, Cdict0, Ydict0)
+ C = nx.to_numpy(C0)
+ Y = nx.to_numpy(Y0)
+ Cdict = nx.to_numpy(Cdict0)
+ Ydict = nx.to_numpy(Ydict0)
+
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+ d = Y.shape[-1]
+ w = unif(D) # Initialize with uniform weights
+ ns = C.shape[-1]
+ nt = Cdict.shape[-1]
+
+ # modeling (C,Y)
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+ Yembedded = np.sum(w[:, None, None] * Ydict, axis=0)
+
+ # constants depending on q
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+ Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+
+        # 1. Solve FGW transport between (C,Y,p) and the current model (\sum_d w_d Cdict[d], \sum_d w_d Ydict[d], q) fixing the unmixing w
+ Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T)
+ M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features
+ T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True)
+ current_loss = log['fgw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w,
+ T, p, q, const_q, diag_q, current_loss, alpha, reg,
+ tol=tol_inner, max_iter=max_iter_inner, **kwargs)
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
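A minimal usage sketch of the BCD unmixing above, assuming it is exposed as `ot.gromov.fused_gromov_wasserstein_linear_unmixing` with the signature documented in its docstring; the toy shapes and random data are illustrative only:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    ns, nt, D, d = 12, 6, 2, 3
    C = ot.dist(rng.randn(ns, d))        # (ns, ns) structure of the input graph
    Y = rng.randn(ns, d)                 # (ns, d) node features
    Cdict = rng.rand(D, nt, nt)
    Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))  # symmetric structure atoms
    Ydict = rng.randn(D, nt, d)          # feature atoms

    w, Cemb, Yemb, T, loss = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
        C, Y, Cdict, Ydict, alpha=0.5, reg=0.0,
        tol_outer=1e-5, tol_inner=1e-5, max_iter_outer=20, max_iter_inner=200)
    print(w.sum())                       # w lies on the simplex, so ~1.0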
+
+def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs):
+ r"""
+ Returns, for a fixed admissible transport plan,
+ the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]}, \sum_d w_d \mathbf{Y_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{i,j} - reg \| \mathbf{w} \|_2^2
+
+ Such that :
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) feature matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`(\mathbf{C},\mathbf{Y})`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}`
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is conditional gradient descent, as discussed in [38], Algorithm 7.
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt,d)
+ Embedded features of (C,Y) onto the dictionary
+ w: array-like, shape (D,)
+ Linear unmixing of (C,Y) onto (Cdict,Ydict)
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution.
+ diag_q: array-like, shape (nt,nt)
+ diagonal matrix with values of q on the diagonal.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between (C,Y) and its model
+ p : array-like, shape (ns,)
+ Distribution in the source space (C,Y).
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ Linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given the OT plan from the previous unmixing.
+ Cembedded: ndarray (nt,nt)
+ Updated embedded structure :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ Yembedded: ndarray (nt,d)
+ Updated embedded features :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
+ current_loss: float
+ Updated reconstruction error.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+ ones_ns_d = np.ones(Y.shape)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+ previous_loss = current_loss
+
+ # 1) Compute gradient at current point w
+ # structure
+ grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ # feature
+ grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2))
+ grad_w -= reg * w
+ grad_w *= 2
+
+ # 2) Conditional gradient direction finding: x = argmin_{x in simplex} x^T grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ Yembedded += gamma * Yembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ count += 1
+
+ return w, Cembedded, Yembedded, current_loss
+
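Step 2) of the loop above is the linear minimization oracle of a Frank-Wolfe step over the probability simplex; a self-contained sketch of that oracle (the helper name is illustrative):

    import numpy as np

    def simplex_minimizer(grad_w):
        # argmin over the simplex of <x, grad_w>: put all mass on the
        # coordinates attaining the minimal gradient, split uniformly
        # in case of ties, exactly as in step 2) above.
        x = (grad_w == grad_w.min()).astype(np.float64)
        return x / x.sum()

    print(simplex_minimizer(np.array([0.3, -1.2, 0.7])))  # [0. 1. 0.]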
+
+def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line-search problem of Fused Gromov-Wasserstein linear unmixing
+
+ .. math::
+ \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 T_{i,j} - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that :
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+ grad_w : array-like, shape (D,)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Y: array-like, shape (ns,d)
+ Feature matrix of the input space
+ Cdict : list of D array-like, shape (nt, nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt, d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt, nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt, d)
+ Embedded features of (C,Y) onto the dictionary
+ T: array-like, shape (ns, nt)
+ Fixed transport plan between (C,Y) and its current model.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ ones_ns_d: array-like, shape (ns, d)
+ :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ Cembedded_diff: numpy array, shape (nt, nt)
+ Difference between the embedded structure matrices of the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ Yembedded_diff: numpy array, shape (nt, d)
+ Difference between the embedded feature matrices of the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ """
+ # polynomial coefficients from quadratic objective (with respect to w) on structures
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ # Constant factors appearing in the factorization a*gamma^2 + b*gamma + c of the Gromov-Wasserstein reconstruction loss
+ a_gw = trace_diffx - trace_diffw
+ b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+
+ # polynomial coefficients from quadratic objective (with respect to w) on features
+ Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0)
+ Yembedded_diff = Yembedded_x - Yembedded
+ # Constant factors appearing in the factorization a*gamma^2 + b*gamma + c of the feature part of the reconstruction loss
+ a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T)
+ b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T)))
+
+ a = alpha * a_gw + (1 - alpha) * a_w
+ b = alpha * b_gw + (1 - alpha) * b_w
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+ if a > 0:
+ gamma = min(1, max(0, -b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff, Yembedded_diff
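The final branch above is the closed-form minimizer of the quadratic a*gamma^2 + b*gamma + c over [0, 1]; a standalone sketch for reference (helper name illustrative):

    def quadratic_linesearch_01(a, b):
        # Minimize a*gamma**2 + b*gamma over [0, 1]: interior optimum
        # when the parabola is convex (a > 0), otherwise pick the best
        # endpoint by comparing f(1) - f(0) = a + b with 0.
        if a > 0:
            return min(1.0, max(0.0, -b / (2.0 * a)))
        return 1.0 if a + b < 0 else 0.0

    print(quadratic_linesearch_01(2.0, -1.0))   # 0.25 (interior optimum)
    print(quadratic_linesearch_01(-1.0, -1.0))  # 1.0 (concave case)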
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 5da897d..390c32d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,6 +26,8 @@ from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend + + __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -220,7 +222,15 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + .. note:: This function will cast the computed transport plan to the data type + of the provided input with the following priority: :math:`\mathbf{a}`, + then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. Uses the algorithm proposed in :ref:`[1] <references-emd>`. @@ -287,12 +297,16 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M + if len(a0) != 0: + type_as = a0 + elif len(b0) != 0: + type_as = b0 + else: + type_as = M0 nx = get_backend(M0, a0, b0) # convert to numpy - M = nx.to_numpy(M) - a = nx.to_numpy(a) - b = nx.to_numpy(b) + M, a, b = nx.to_numpy(M, a, b) # ensure float64 a = np.asarray(a, dtype=np.float64) @@ -327,15 +341,23 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): u, v = estimate_dual_null_weights(u, v, a, b, M) result_code_string = check_result(result_code) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2 + ) if log: log = {} log['cost'] = cost - log['u'] = nx.from_numpy(u, type_as=a0) - log['v'] = nx.from_numpy(v, type_as=b0) + log['u'] = nx.from_numpy(u, type_as=type_as) + log['v'] = nx.from_numpy(v, type_as=type_as) log['warning'] = result_code_string log['result_code'] = result_code - return nx.from_numpy(G, type_as=M0), log - return nx.from_numpy(G, type_as=M0) + return nx.from_numpy(G, type_as=type_as), log + return nx.from_numpy(G, type_as=type_as) def emd2(a, b, M, processes=1, @@ -358,7 +380,16 @@ def emd2(a, b, M, processes=1, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + .. note:: This function will cast the computed transport plan and + transportation loss to the data type of the provided input with the + following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, + then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. Uses the algorithm proposed in :ref:`[1] <references-emd2>`. 
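A short sketch of the casting rule described in the notes above: with integer marginals, ot.emd warns and casts the returned plan to the dtype of a (the exact integer dtype is platform-dependent):

    import numpy as np
    import ot

    a = np.array([1, 1])                  # integer marginals (equal total mass)
    b = np.array([1, 1])
    M = np.array([[0.0, 1.0], [1.0, 0.0]])

    G = ot.emd(a, b, M)                   # warns; plan cast to a's dtype
    print(G.dtype)                        # e.g. int64
    G = ot.emd(a.astype(np.float64), b.astype(np.float64), M)
    print(G.dtype)                        # float64, no warning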
@@ -428,12 +459,16 @@ def emd2(a, b, M, processes=1, a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M + if len(a0) != 0: + type_as = a0 + elif len(b0) != 0: + type_as = b0 + else: + type_as = M0 nx = get_backend(M0, a0, b0) # convert to numpy - M = nx.to_numpy(M) - a = nx.to_numpy(a) - b = nx.to_numpy(b) + M, a, b = nx.to_numpy(M, a, b) a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) @@ -466,15 +501,24 @@ def emd2(a, b, M, processes=1, result_code_string = check_result(result_code) log = {} - G = nx.from_numpy(G, type_as=M0) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2 + ) + G = nx.from_numpy(G, type_as=type_as) if return_matrix: log['G'] = G - log['u'] = nx.from_numpy(u, type_as=a0) - log['v'] = nx.from_numpy(v, type_as=b0) + log['u'] = nx.from_numpy(u, type_as=type_as) + log['v'] = nx.from_numpy(v, type_as=type_as) log['warning'] = result_code_string log['result_code'] = result_code - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0, b0, M0), (log['u'], log['v'], G)) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), (log['u'] - nx.mean(log['u']), + log['v'] - nx.mean(log['v']), G)) return [cost, log] else: def f(b): @@ -487,10 +531,18 @@ def emd2(a, b, M, processes=1, if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) - G = nx.from_numpy(G, type_as=M0) - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0, b0, M0), (nx.from_numpy(u, type_as=a0), - nx.from_numpy(v, type_as=b0), G)) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. 
" + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2 + ) + G = nx.from_numpy(G, type_as=type_as) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), G)) check_result(result_code) return cost @@ -535,18 +587,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Parameters ---------- - measures_locations : list of N (k_i,d) numpy.ndarray + measures_locations : list of N (k_i,d) array-like The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space (:math:`k_i` can be different for each element of the list) - measures_weights : list of N (k_i,) numpy.ndarray + measures_weights : list of N (k_i,) array-like Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one representing the weights of each discrete input measure - X_init : (k,d) np.ndarray + X_init : (k,d) array-like Initialization of the support locations (on `k` atoms) of the barycenter - b : (k,) np.ndarray + b : (k,) array-like Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (N,) np.ndarray + weights : (N,) array-like Initialization of the coefficients of the barycenter (non-negatives, sum to 1) numItermax : int, optional @@ -564,7 +616,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Returns ------- - X : (k,d) np.ndarray + X : (k,d) array-like Support locations (on k atoms) of the barycenter @@ -577,15 +629,17 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None """ + nx = get_backend(*measures_locations,*measures_weights,X_init) + iter_count = 0 N = len(measures_locations) k = X_init.shape[0] d = X_init.shape[1] if b is None: - b = np.ones((k,)) / k + b = nx.ones((k,),type_as=X_init) / k if weights is None: - weights = np.ones((N,)) / N + weights = nx.ones((N,),type_as=X_init) / N X = X_init @@ -596,15 +650,15 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None while (displacement_square_norm > stopThr and iter_count < numItermax): - T_sum = np.zeros((k, d)) + T_sum = nx.zeros((k, d),type_as=X_init) + - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, - weights.tolist()): + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) + T_sum = T_sum + weight_i * 1. 
/ b[:,None] * nx.dot(T_i, measure_locations_i) - displacement_square_norm = np.sum(np.square(T_sum - X)) + displacement_square_norm = nx.sum((T_sum - X)**2) if log: displacement_square_norms.append(displacement_square_norm) @@ -620,3 +674,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X + diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 869d450..fbf3c0e 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -11,7 +11,6 @@ import numpy as np import scipy as sp import scipy.sparse as sps - try: import cvxopt from cvxopt import solvers, matrix, spmatrix diff --git a/ot/optim.py b/ot/optim.py index f25e2c9..5a1d605 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -9,12 +9,19 @@ Generic solvers for regularized OT # License: MIT License import numpy as np -from scipy.optimize.linesearch import scalar_search_armijo +import warnings from .lp import emd from .bregman import sinkhorn -from ot.utils import list_to_array +from .utils import list_to_array from .backend import get_backend +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + from scipy.optimize import scalar_search_armijo + except ImportError: + from scipy.optimize.linesearch import scalar_search_armijo + # The corresponding scipy function does not work for matrices diff --git a/ot/partial.py b/ot/partial.py index b7093e4..0a9e450 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -7,7 +7,6 @@ Partial OT solvers # License: MIT License import numpy as np - from .lp import emd @@ -29,7 +28,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, \gamma &\geq 0 - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m & + \leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. @@ -50,7 +50,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, - :math:`\lambda` is the lagrangian cost. 
Tuning its value allows attaining a given mass to be transported `m` - The formulation of the problem has been proposed in :ref:`[28] <references-partial-wasserstein-lagrange>` + The formulation of the problem has been proposed in + :ref:`[28] <references-partial-wasserstein-lagrange>` Parameters @@ -261,7 +262,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) M_extended = np.zeros((len(a_extended), len(b_extended))) - M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 + M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2 M_extended[:len(a), :len(b)] = M gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, @@ -455,7 +456,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein>` + The formulation of the problem has been proposed in + :ref:`[29] <references-partial-gromov-wasserstein>` Parameters @@ -469,7 +471,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -623,16 +626,19 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, \gamma &\geq 0 - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein2>` + The formulation of the problem has been proposed in + :ref:`[29] <references-partial-gromov-wasserstein2>` Parameters @@ -646,7 +652,8 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -728,21 +735,25 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, The function considers the following problem: .. 
math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, + \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\ \gamma^T \mathbf{1} &\leq \mathbf{b} \\ \gamma &\geq 0 \\ - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ where : - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5) + The formulation of the problem has been proposed in + :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5) Parameters @@ -829,12 +840,23 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, np.multiply(K, m / np.sum(K), out=K) err, cpt = 1, 0 + q1 = np.ones(K.shape) + q2 = np.ones(K.shape) + q3 = np.ones(K.shape) while (err > stopThr and cpt < numItermax): Kprev = K + K = K * q1 K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K) + q1 = q1 * Kprev / K1 + K1prev = K1 + K1 = K1 * q2 K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy))) + q2 = q2 * K1prev / K2 + K2prev = K2 + K2 = K2 * q3 K = K2 * (m / np.sum(K2)) + q3 = q3 * K2prev / K if np.any(np.isnan(K)) or np.any(np.isinf(K)): print('Warning: numerical errors at iteration', cpt) @@ -861,7 +883,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the partial Gromov-Wasserstein transport between + :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: @@ -877,7 +900,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, \gamma^T \mathbf{1} &\leq \mathbf{b} - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : @@ -885,10 +909,13 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, - :math:`\mathbf{C_2}` is the metric cost matrix in the target space - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L`: quadratic loss function - - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein>` + The formulation of the GW problem has been proposed in + :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the + partial GW in :ref:`[29] 
<references-entropic-partial-gromov-wassertein>` Parameters ---------- @@ -903,7 +930,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional @@ -1005,13 +1033,15 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the partial Gromov-Wasserstein discrepancy between + :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, + \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) .. math:: s.t. \ \gamma &\geq 0 @@ -1028,10 +1058,13 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, - :math:`\mathbf{C_2}` is the metric cost matrix in the target space - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L` : quadratic loss function - - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>` + The formulation of the GW problem has been proposed in + :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the + partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>` Parameters @@ -1047,7 +1080,8 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional @@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): if ('color' not in kwargs) and ('c' not in kwargs): kwargs['color'] = 'k' mx = G.max() + if 'alpha' in kwargs: + scale = kwargs['alpha'] + del kwargs['alpha'] + else: + scale = 1 for i in range(xs.shape[0]): for j in range(xt.shape[0]): if G[i, j] / mx > thr: pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], - alpha=G[i, j] / mx, **kwargs) + alpha=G[i, j] / mx * scale, **kwargs) diff --git a/ot/regpath.py b/ot/regpath.py index 269937a..e745288 100644 --- a/ot/regpath.py +++ b/ot/regpath.py @@ -11,34 +11,48 @@ import scipy.sparse as sp def recast_ot_as_lasso(a, b, C): - r"""This function recasts the l2-penalized UOT problem as a Lasso problem + r"""This 
function recasts the l2-penalized UOT problem as a Lasso problem. + + Recall the l2-penalized UOT problem defined in + :ref:`[41] <references-regpath>` - Recall the l2-penalized UOT problem defined in [Chapel et al., 2021] .. math:: - UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2 + - \lambda \|T^T 1_n - b\|_2^2 + \text{UOT}_{\lambda} = \min_T <C, T> + \lambda \|T 1_m - + \mathbf{a}\|_2^2 + + \lambda \|T^T 1_n - \mathbf{b}\|_2^2 + s.t. T \geq 0 + where : - - C is the (dim_a, dim_b) metric cost matrix - - :math:`\lambda` is the l2-regularization coefficient - - a and b are source and target distributions - - T is the transport plan to optimize - The problem above can be reformulated to a non-negative penalized + - :math:`C` is the cost matrix + - :math:`\lambda` is the l2-regularization parameter + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \ + distributions + - :math:`T` is the transport plan to optimize + + The problem above can be reformulated as a non-negative penalized linear regression problem, particularly Lasso + .. math:: - UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + \text{UOT2}_{\lambda} = \min_{\mathbf{t}} \gamma \mathbf{c}^T + \mathbf{t} + 0.5 * \|H \mathbf{t} - \mathbf{y}\|_2^2 + s.t. - t \geq 0 + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) - - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] - - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, - see [Chapel et al., 2021] for the design of H. The matrix product H t - computes both the source marginal and the target marginal. - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + + - :math:`\mathbf{c}` is the flattened version of the cost matrix :math:`C` + - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \ + and :math:`\mathbf{b}` + - :math:`H` is a metric matrix, see :ref:`[41] <references-regpath>` for \ + the design of :math:`H`. The matrix product :math:`H\mathbf{t}` \ + computes both the source marginal and the target marginals. + - :math:`\mathbf{t}` is the flattened version of the transport plan \ + :math:`T` + Parameters ---------- a : np.ndarray (dim_a,) @@ -47,14 +61,16 @@ def recast_ot_as_lasso(a, b, C): Histogram of dimension dim_b C : np.ndarray, shape (dim_a, dim_b) Cost matrix + Returns ------- H : np.ndarray (dim_a+dim_b, dim_a*dim_b) - Auxiliary matrix constituted by 0 and 1 + Design matrix that contains only 0 and 1 y : np.ndarray (ns + nt, ) - Concatenation of histogram a and histogram b + Concatenation of histograms :math:`\mathbf{a}` and :math:`\mathbf{b}` c : np.ndarray (ns * nt, ) - Flattened array of cost matrix + Flattened array of the cost matrix + Examples -------- >>> import ot @@ -73,12 +89,12 @@ def recast_ot_as_lasso(a, b, C): >>> c array([16., 25., 28., 16., 40., 36.]) + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ dim_a = np.shape(a)[0] @@ -97,33 +113,47 @@ def recast_ot_as_lasso(a, b, C): def recast_semi_relaxed_as_lasso(a, b, C): - r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem + r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem. .. 
math:: - semi-relaxed UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2 + + \text{semi-relaxed UOT} = \min_T <C, T> + + \lambda \|T 1_m - \mathbf{a}\|_2^2 + s.t. - T^T 1_n = b - t \geq 0 + T^T 1_n = \mathbf{b} + + \mathbf{t} \geq 0 + where : - - C is the (dim_a, dim_b) metric cost matrix - - :math:`\lambda` is the l2-regularization coefficient - - a and b are source and target distributions - - T is the transport plan to optimize + + - :math:`C` is the metric cost matrix + - :math:`\lambda` is the l2-regularization parameter + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \ + distributions + - :math:`T` is the transport plan to optimize The problem above can be reformulated as follows + .. math:: - semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + \text{semi-relaxed UOT2} = \min_t \gamma \mathbf{c}^T t + + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2 + s.t. - H_c t = b - t \geq 0 + H_c \mathbf{t} = \mathbf{b} + + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) - - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - H_r is a (dim_a, dim_a * dim_b) metric matrix, - which computes the sum along the rows of transport plan T - - H_c is a (dim_b, dim_a * dim_b) metric matrix, - which computes the sum along the columns of transport plan T - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + + - :math:`\mathbf{c}` is flattened version of the cost matrix :math:`C` + - :math:`\gamma = 1/\lambda` is the l2-regularization parameter + - :math:`H_r` is a metric matrix which computes the sum along the \ + rows of the transport plan :math:`T` + - :math:`H_c` is a metric matrix which computes the sum along the \ + columns of the transport plan :math:`T` + - :math:`\mathbf{t}` is the flattened version of :math:`T` + Parameters ---------- a : np.ndarray (dim_a,) @@ -132,16 +162,18 @@ def recast_semi_relaxed_as_lasso(a, b, C): Histogram of dimension dim_b C : np.ndarray, shape (dim_a, dim_b) Cost matrix + Returns ------- Hr : np.ndarray (dim_a, dim_a * dim_b) Auxiliary matrix constituted by 0 and 1, which computes - the sum along the rows of transport plan T + the sum along the rows of transport plan :math:`T` Hc : np.ndarray (dim_b, dim_a * dim_b) Auxiliary matrix constituted by 0 and 1, which computes - the sum along the columns of transport plan T + the sum along the columns of transport plan :math:`T` c : np.ndarray (ns * nt, ) - Flattened array of cost matrix + Flattened array of the cost matrix + Examples -------- >>> import ot @@ -179,49 +211,60 @@ def recast_semi_relaxed_as_lasso(a, b, C): def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma): r""" This function computes the next value of gamma if a variable - will be added in next iteration of the regularization path + is added in the next iteration of the regularization path. We look for the largest value of gamma such that the gradient of an inactive variable vanishes + .. 
math:: - \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i} + \max_{i \in \bar{A}} \frac{\mathbf{h}_i^T(H_A \phi - \mathbf{y})} + {\mathbf{h}_i^T H_A \delta - \mathbf{c}_i} + where : + - A is the current active set - - h_i is the ith column of auxiliary matrix H - - H_A is the sub-matrix constructed by the columns of H - whose indices belong to the active set A - - c_i is the ith element of cost vector c - - y is the concatenation of source and target distribution - - :math:`\phi` is the intercept of the solutions in current iteration - - :math:`\delta` is the slope of the solutions in current iteration + - :math:`\mathbf{h}_i` is the :math:`i` th column of the design \ + matrix :math:`{H}` + - :math:`{H}_A` is the sub-matrix constructed by the columns of \ + :math:`{H}` whose indices belong to the active set A + - :math:`\mathbf{c}_i` is the :math:`i` th element of the cost vector \ + :math:`\mathbf{c}` + - :math:`\mathbf{y}` is the concatenation of the source and target \ + distributions + - :math:`\phi` is the intercept of the solutions at the current iteration + - :math:`\delta` is the slope of the solutions at the current iteration + Parameters ---------- - phi : np.ndarray (|A|, ) - Intercept of the solutions in current iteration (t is piecewise linear) - delta : np.ndarray (|A|, ) - Slope of the solutions in current iteration (t is piecewise linear) + phi : np.ndarray (size(A), ) + Intercept of the solutions at the current iteration + delta : np.ndarray (size(A), ) + Slope of the solutions at the current iteration HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b) - Matrix product of H^T H + Matrix product of :math:`{H}^T {H}` Hty : np.ndarray (dim_a + dim_b, ) - Matrix product of H^T y + Matrix product of :math:`{H}^T \mathbf{y}` c: np.ndarray (dim_a * dim_b, ) - Flattened array of cost matrix C + Flattened array of the cost matrix :math:`{C}` active_index : list Indices of active variables current_gamma : float - Value of regularization coefficient at the start of current iteration + Value of the regularization parameter at the beginning of the current \ + iteration + Returns ------- next_gamma : float Value of gamma if a variable is added to active set in next iteration next_active_index : int Index of variable to be activated + + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ M = (HtH[:, active_index].dot(phi) - Hty) / \ (HtH[:, active_index].dot(delta) - c + 1e-16) @@ -237,56 +280,65 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, By taking the Lagrangian form of the problem, we obtain a similar update as the two-sided relaxed UOT + .. 
math:: - \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T - \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i} \delta_u - c_i} + + \max_{i \in \bar{A}} \frac{\mathbf{h}_{ri}^T(H_{rA} \phi - \mathbf{a}) + + \mathbf{h}_{c i}^T\phi_u}{\mathbf{h}_{r i}^T H_{r A} \delta + \ + \mathbf{h}_{c i} \delta_u - \mathbf{c}_i} + where : + - A is the current active set - - h_{r i} is the ith column of the matrix H_r - - h_{c i} is the ith column of the matrix H_c - - H_{r A} is the sub-matrix constructed by the columns of H_r - whose indices belong to the active set A - - c_i is the ith element of cost vector c - - y is the concatenation of source and target distribution + - :math:`\mathbf{h}_{r i}` is the ith column of the matrix :math:`H_r` + - :math:`\mathbf{h}_{c i}` is the ith column of the matrix :math:`H_c` + - :math:`H_{r A}` is the sub-matrix constructed by the columns of \ + :math:`H_r` whose indices belong to the active set A + - :math:`\mathbf{c}_i` is the :math:`i` th element of cost vector \ + :math:`\mathbf{c}` - :math:`\phi` is the intercept of the solutions in current iteration - :math:`\delta` is the slope of the solutions in current iteration - - :math:`\phi_u` is the intercept of Lagrange parameter in current - iteration - - :math:`\delta_u` is the slope of Lagrange parameter in current iteration + - :math:`\phi_u` is the intercept of Lagrange parameter at the \ + current iteration + - :math:`\delta_u` is the slope of Lagrange parameter at the \ + current iteration + Parameters ---------- - phi : np.ndarray (|A|, ) - Intercept of the solutions in current iteration (t is piecewise linear) - delta : np.ndarray (|A|, ) - Slope of the solutions in current iteration (t is piecewise linear) + phi : np.ndarray (size(A), ) + Intercept of the solutions at the current iteration + delta : np.ndarray (size(A), ) + Slope of the solutions at the current iteration phi_u : np.ndarray (dim_b, ) - Intercept of the Lagrange parameter in current iteration (also linear) + Intercept of the Lagrange parameter at the current iteration delta_u : np.ndarray (dim_b, ) - Slope of the Lagrange parameter in current iteration (also linear) + Slope of the Lagrange parameter at the current iteration HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) - Matrix product of H_r^T H_r + Matrix product of :math:`H_r^T H_r` Hc : np.ndarray (dim_b, dim_a * dim_b) - Matrix that computes the sum along the columns of transport plan T + Matrix that computes the sum along the columns of the transport plan \ + :math:`T` Hra : np.ndarray (dim_a * dim_b, ) - Matrix product of H_r^T a + Matrix product of :math:`H_r^T \mathbf{a}` c: np.ndarray (dim_a * dim_b, ) - Flattened array of cost matrix C + Flattened array of cost matrix :math:`C` active_index : list Indices of active variables current_gamma : float Value of regularization coefficient at the start of current iteration + Returns ------- next_gamma : float Value of gamma if a variable is added to active set in next iteration next_active_index : int Index of variable to be activated + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. 
""" M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \ @@ -297,37 +349,48 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, def compute_next_removal(phi, delta, current_gamma): - r""" This function computes the next value of gamma if a variable - is removed in next iteration of regularization path + r""" This function computes the next gamma value if a variable + is removed at the next iteration of the regularization path. + + We look for the largest value of the regularization parameter such that + an element of the current solution vanishes - We look for the largest value of gamma such that - an element of current solution vanishes .. math:: \max_{j \in A} \frac{\phi_j}{\delta_j} + where : + - A is the current active set - - phi_j is the jth element of the intercept of current solution - - delta_j is the jth elemnt of the slope of current solution + - :math:`\phi_j` is the :math:`j` th element of the intercept of the \ + current solution + - :math:`\delta_j` is the :math:`j` th element of the slope of the \ + current solution + + Parameters ---------- - phi : np.ndarray (|A|, ) - Intercept of the solutions in current iteration (t is piecewise linear) - delta : np.ndarray (|A|, ) - Slope of the solutions in current iteration (t is piecewise linear) + phi : ndarray, shape (size(A), ) + Intercept of the solution at the current iteration + delta : ndarray, shape (size(A), ) + Slope of the solution at the current iteration current_gamma : float - Value of regularization coefficient at the start of current iteration + Value of the regularization parameter at the beginning of the \ + current iteration + Returns ------- next_removal_gamma : float - Value of gamma if a variable is removed in next iteration + Gamma value if a variable is removed at the next iteration next_removal_index : int - Index of the variable to remove in next iteration + Index of the variable to be removed at the next iteration + + + .. _references-regpath: References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ r_candidate = phi / (delta - 1e-16) r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0 @@ -335,56 +398,74 @@ def compute_next_removal(phi, delta, current_gamma): def complement_schur(M_current, b, d, id_pop): - r""" This function computes the inverse of matrix in regularization path - using Schur complement + r""" This function computes the inverse of the design matrix in the \ + regularization path using the Schur complement. Two cases may arise: + + Case 1: one variable is added to the active set + - Two cases may arise: Firstly one variable is added to the active set .. math:: M_{k+1}^{-1} = \begin{bmatrix} - M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\ - - s^{-1} b^T M_{k}^{-1} & s^{-1} + M_{k}^{-1} + s^{-1} M_{k}^{-1} \mathbf{b} \mathbf{b}^T M_{k}^{-1} \ + & - M_{k}^{-1} \mathbf{b} s^{-1} \\ + - s^{-1} \mathbf{b}^T M_{k}^{-1} & s^{-1} \end{bmatrix} + + where : - - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and - :math:`M_k` is the upper left block matrix in Schur formulation - - b is the upper right block matrix in Schur formulation. 
In our case, - b is reduced to a column vector and b^T is the lower left block matrix - - s is the Schur complement, given by - :math:`s = d - b^T M_{k}^{-1} b` in our case - - Secondly, one variable is removed from the active set + + - :math:`M_k^{-1}` is the inverse of the design matrix :math:`H_A^tH_A` \ + of the previous iteration + - :math:`\mathbf{b}` is the last column of :math:`M_{k}` + - :math:`s` is the Schur complement, given by \ + :math:`s = \mathbf{d} - \mathbf{b}^T M_{k}^{-1} \mathbf{b}` + + Case 2: one variable is removed from the active set. + .. math:: - M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} - + M_{k+1}^{-1} = M^{-1}_{k \backslash q} - \frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}} + where : - - q is the index of column and row to delete - - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix - without qth column and qth row - - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element - - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}` + + - :math:`q` is the index of column and row to delete + - :math:`M^{-1}_{k \backslash q}` is the previous inverse matrix deprived \ + of the :math:`q` th column and :math:`q` th row + - :math:`r_{-q,q}` is the :math:`q` th column of :math:`M^{-1}_{k}` \ + without the :math:`q` th element + - :math:`r_{q, q}` is the element of :math:`q` th column and :math:`q` th \ + row in :math:`M^{-1}_{k}` + + Parameters ---------- - M_current : np.ndarray (|A|-1, |A|-1) - Inverse matrix in previous iteration - b : np.ndarray (|A|-1, ) - Upper right matrix in Schur complement, a column vector in our case + M_current : ndarray, shape (size(A)-1, size(A)-1) + Inverse matrix of :math:`H_A^tH_A` at the previous iteration, with \ + size(A) the size of the active set + b : ndarray, shape (size(A)-1, ) + None for case 2 (removal), last column of :math:`M_{k}` for case 1 \ + (addition) d : float - Lower right matrix in Schur complement, a scalar in our case - id_pop + should be equal to 2 when UOT and 1 for the semi-relaxed OT + id_pop : int Index of the variable to be removed, equal to -1 - if none of the variables is deleted in current iteration + if no variable is deleted at the current iteration + + Returns ------- - M : np.ndarray (|A|, |A|) - Inverse matrix needed in current iteration + M : ndarray, shape (size(A), size(A)) + Inverse matrix of :math:`H_A^tH_A` of the current iteration + + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ + if b is None: b = M_current[id_pop, :] b = np.delete(b, id_pop) @@ -409,33 +490,39 @@ def complement_schur(M_current, b, d, id_pop): def construct_augmented_H(active_index, m, Hc, HrHr): - r""" This function construct an augmented matrix for the first iteration of - semi-relaxed regularization path + r""" This function constructs an augmented matrix for the first iteration + of the semi-relaxed regularization path .. 
math:: - Augmented_H = + \text{Augmented}_H = \begin{bmatrix} 0 & H_{c A} \\ H_{c A}^T & H_{r A}^T H_{r A} \end{bmatrix} + where : - - H_{r A} is the sub-matrix constructed by the columns of H_r - whose indices belong to the active set A - - H_{c A} is the sub-matrix constructed by the columns of H_c - whose indices belong to the active set A + + - :math:`H_{r A}` is the sub-matrix constructed by the columns of \ + :math:`H_r` whose indices belong to the active set A + - :math:`H_{c A}` is the sub-matrix constructed by the columns of \ + :math:`H_c` whose indices belong to the active set A + + Parameters ---------- active_index : list - Indices of active variables + Indices of the active variables m : int Length of the target distribution Hc : np.ndarray (dim_b, dim_a * dim_b) - Matrix that computes the sum along the columns of transport plan T + Matrix that computes the sum along the columns of the transport plan \ + :math:`T` HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) - Matrix product of H_r^T H_r + Matrix product of :math:`H_r^T H_r` + Returns ------- - H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|) + H_augmented : np.ndarray (dim_b + size(A), dim_b + size(A)) Augmented matrix for the first iteration of the semi-relaxed regularization path """ @@ -451,18 +538,27 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, r"""This function gives the regularization path of l2-penalized UOT problem The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: - \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + \min_t \gamma \mathbf{c}^T \mathbf{t} + + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2 + s.t. - t \geq 0 + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + + - :math:`\mathbf{c}` is the flattened version of the cost matrix \ + :math:`{C}` - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] - - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, - see [Chapel et al., 2021] for the design of H. The matrix product Ht - computes both the source marginal and the target marginal. - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \ + and :math:`\mathbf{b}`, defined as \ + :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]` + - :math:`{H}` is a design matrix, see :ref:`[41] <references-regpath>` \ + for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \ + computes both the source marginal and the target marginals. + - :math:`\mathbf{t}` is the flattened version of the transport matrix + Parameters ---------- a : np.ndarray (dim_a,) @@ -478,11 +574,12 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, Returns ------- t : np.ndarray (dim_a*dim_b, ) - Flattened vector of optimal transport matrix + Flattened vector of the optimal transport matrix t_list : list - List of solutions in regularization path + List of solutions in the regularization path gamma_list : list - List of regularization coefficient in regularization path + List of regularization coefficients in the regularization path + Examples -------- >>> import ot @@ -502,10 +599,9 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. 
[41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ n = np.shape(a)[0] @@ -580,22 +676,32 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, itmax=50000): r"""This function gives the regularization path of semi-relaxed - l2-UOT problem + l2-UOT problem. The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: - \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + + \min_t \gamma \mathbf{c}^T t + + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2 + s.t. - H_c t = b - t \geq 0 + H_c \mathbf{t} = \mathbf{b} + + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) - - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - H_r is a (dim_a, dim_a * dim_b) metric matrix, - which computes the sum along the rows of transport plan T - - H_c is a (dim_b, dim_a * dim_b) metric matrix, - which computes the sum along the columns of transport plan T - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + + - :math:`\mathbf{c}` is the flattened version of the cost matrix \ + :math:`C` + - :math:`\gamma = 1/\lambda` is the l2-regularization parameter + - :math:`H_r` is a matrix that computes the sum along the rows of \ + the transport plan :math:`T` + - :math:`H_c` is a matrix that computes the sum along the columns of \ + the transport plan :math:`T` + - :math:`\mathbf{t}` is the flattened version of the transport plan \ + :math:`T` + Parameters ---------- a : np.ndarray (dim_a,) @@ -608,14 +714,16 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, l2-regularization coefficient itmax: int (optional) Maximum number of iteration + Returns ------- t : np.ndarray (dim_a*dim_b, ) - Flattened vector of optimal transport matrix + Flattened vector of the (unregularized) optimal transport matrix t_list : list - List of solutions in regularization path + List of all the optimal transport vectors of the regularization path gamma_list : list - List of regularization coefficient in regularization path + List of the regularization parameters in the path + Examples -------- >>> import ot @@ -635,10 +743,9 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ n = np.shape(a)[0] @@ -722,8 +829,44 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, semi_relaxed=False, itmax=50000): - r"""This function combines both the semi-relaxed and the fully-relaxed - regularization paths of l2-UOT problem + r"""This function provides all the solutions of the regularization path \ + of the l2-UOT problem :ref:`[41] <references-regpath>`. + + The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + + .. math:: + \min_t \gamma \mathbf{c}^T \mathbf{t} + + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2 + + s.t. 
+ \mathbf{t} \geq 0 + + where : + + - :math:`\mathbf{c}` is the flattened version of the cost matrix \ + :math:`{C}` + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \ + and :math:`\mathbf{b}`, defined as \ + :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]` + - :math:`{H}` is a design matrix, see :ref:`[41] <references-regpath>` \ + for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \ + computes both the source marginal and the target marginals. + - :math:`\mathbf{t}` is the flattened version of the transport matrix + + For the semi-relaxed problem, it optimizes the Lasso reformulation of the + l2-penalized UOT: + + .. math:: + + \min_t \gamma \mathbf{c}^T \mathbf{t} + + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2 + + s.t. + H_c \mathbf{t} = \mathbf{b} + + \mathbf{t} \geq 0 + Parameters ---------- @@ -736,23 +879,24 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, reg: float (optional) l2-regularization coefficient semi_relaxed : bool (optional) - Give the semi-relaxed path if true + Give the semi-relaxed path if True itmax: int (optional) Maximum number of iteration + Returns ------- t : np.ndarray (dim_a*dim_b, ) - Flattened vector of optimal transport matrix + Flattened vector of the (unregularized) optimal transport matrix t_list : list - List of solutions in regularization path + List of all the optimal transport vectors of the regularization path gamma_list : list - List of regularization coefficient in regularization path + List of the regularization parameters in the path + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ if semi_relaxed: t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg, @@ -765,27 +909,33 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, def compute_transport_plan(gamma, gamma_list, Pi_list): r""" Given the regularization path, this function computes the transport - plan for any value of gamma by the piecewise linearity of the path + plan for any value of gamma thanks to the piecewise linearity of the path. .. math:: t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma) - where : - - :math:`\gamma` is the regularization coefficient + + where: + + - :math:`\gamma` is the regularization parameter - :math:`\phi(\gamma)` is the corresponding intercept - :math:`\delta(\gamma)` is the corresponding slope - - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix) + - :math:`\mathbf{t}` is the flattened version of the transport matrix + Parameters ---------- gamma : float Regularization coefficient gamma_list : list - List of regularization coefficients in regularization path + List of regularization parameters of the regularization path Pi_list : list - List of solutions in regularization path + List of all the solutions of the regularization path + Returns ------- t : np.ndarray (dim_a*dim_b, ) - Transport vector corresponding to the given value of gamma + Vectorization of the transport plan corresponding to the given value + of gamma + Examples -------- >>> import ot @@ -804,12 +954,13 @@ def compute_transport_plan(gamma, gamma_list, Pi_list): array([0. , 0. , 0. , 0.19722222, 0.05555556, 0. , 0. , 0.24722222, 0. ]) + + .. 
+ .. _references-regpath:
+
 References
 ----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
 Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
 """
 if gamma >= gamma_list[0]:
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 693675f..61be9bb 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -4,12 +4,14 @@ Stochastic solvers for regularized OT.
 """
-# Author: Kilian Fatras <kilian.fatras@gmail.com>
+# Authors: Kilian Fatras <kilian.fatras@gmail.com>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
 #
 # License: MIT License
 import numpy as np
-
+from .utils import dist
+from .backend import get_backend
 ##############################################################################
 # Optimization toolbox for SEMI - DUAL problems
 ##############################################################################
@@ -747,3 +749,239 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
 return pi, log
 else:
 return pi
+
+
+################################################################################
+# Losses for stochastic optimization
+################################################################################
+
+def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the dual loss of the entropic OT as in equations (6)-(7) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -reg * nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
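Editor's note: the dual loss above is meant to be dropped into any autodiff optimizer. A minimal sketch of one way to use it, assuming the optional PyTorch backend is installed (the toy data, learning rate, and iteration count are illustrative and not part of this commit):

    # Editor's sketch (not part of this diff): gradient ascent on the
    # entropic dual with PyTorch, then recovery of the primal plan.
    import torch
    import ot

    xs = torch.randn(100, 2)        # toy source samples
    xt = torch.randn(100, 2) + 2    # toy target samples
    u = torch.zeros(100, requires_grad=True)  # source dual potential
    v = torch.zeros(100, requires_grad=True)  # target dual potential
    optimizer = torch.optim.Adam([u, v], lr=0.1)

    for _ in range(500):
        optimizer.zero_grad()
        # the dual is a maximization problem, so minimize its negation
        loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=1)
        loss.backward()
        optimizer.step()

    G = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=1)  # primal plan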
+
+
+def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the primal OT plan of the entropic OT as in equation (8) of [19]
+
+ This function is backend compatible and can be used for stochastic
+ optimization of the dual potentials. It can be used on the full dataset
+ (beware of memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ G : array-like
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = nx.exp((u[:, None] + v[None, :] - M) / reg)
+
+ return ws[:, None] * H * wt[None, :]
+
+
+def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the dual loss of the quadratic regularized OT as in equations (6)-(7) of [19]
+
+ This loss is backend compatible and can be used for stochastic optimization
+ of the dual potentials. It can be used on the full dataset (beware of
+ memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ dual_loss : array-like
+ Dual loss (to maximize)
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)**2
+
+ return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :])
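Editor's note: as the docstrings say, these losses can also be evaluated on minibatches. A hedged sketch of one common pattern, indexing the potentials with the same random indices as the samples (again assuming the PyTorch backend; batch size, step count, and learning rate are arbitrary choices, not prescribed by this commit):

    # Editor's sketch (not part of this diff): minibatch maximization of
    # the quadratic-regularized dual.
    import torch
    import ot

    ns, nt = 500, 500
    xs = torch.randn(ns, 2)
    xt = torch.randn(nt, 2) + 2
    u = torch.zeros(ns, requires_grad=True)
    v = torch.zeros(nt, requires_grad=True)
    optimizer = torch.optim.SGD([u, v], lr=0.1)

    for _ in range(2000):
        isrc = torch.randint(0, ns, (50,))  # random source minibatch
        itgt = torch.randint(0, nt, (50,))  # random target minibatch
        optimizer.zero_grad()
        loss = -ot.stochastic.loss_dual_quadratic(u[isrc], v[itgt],
                                                  xs[isrc], xt[itgt], reg=1)
        loss.backward()
        optimizer.step()

    # unlike the entropic plan, the quadratic plan is typically sparse
    G = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=1)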
+
+
+def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'):
+ r"""
+ Compute the primal OT plan of the quadratic regularized OT as in equation (8) of [19]
+
+ This function is backend compatible and can be used for stochastic
+ optimization of the dual potentials. It can be used on the full dataset
+ (beware of memory) or on minibatches.
+
+
+ Parameters
+ ----------
+ u : array-like, shape (ns,)
+ Source dual potential
+ v : array-like, shape (nt,)
+ Target dual potential
+ xs : array-like, shape (ns,d)
+ Source samples
+ xt : array-like, shape (nt,d)
+ Target samples
+ reg : float
+ Regularization term > 0 (default=1)
+ ws : array-like, shape (ns,), optional
+ Source sample weights (default unif)
+ wt : array-like, shape (nt,), optional
+ Target sample weights (default unif)
+ metric : string, callable
+ Ground metric for OT (default quadratic). Can be given as a callable
+ function taking (xs,xt) as parameters.
+
+ Returns
+ -------
+ G : array-like
+ Primal OT plan
+
+
+ References
+ ----------
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018)
+ """
+
+ nx = get_backend(u, v, xs, xt)
+
+ if ws is None:
+ ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0]
+
+ if callable(metric):
+ M = metric(xs, xt)
+ else:
+ M = dist(xs, xt, metric=metric)
+
+ H = 1.0 / (2 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)
+
+ return ws[:, None] * H * wt[None, :]
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 15e180b..90c920c 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -4,13 +4,14 @@ Regularized Unbalanced OT solvers
 """
 # Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
 # License: MIT License
 from __future__ import division
 import warnings
-import numpy as np
-from scipy.special import logsumexp
+from .backend import get_backend
+from .utils import list_to_array
 # from .utils import unif, dist
@@ -43,12 +44,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
 Parameters
 ----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
 Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
 One or multiple unnormalized histograms of dimension `dim_b`.
 If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
 loss matrix
 reg : float
 Entropy regularization term > 0
@@ -70,12 +71,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
 Returns
 -------
 if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
 Optimal transportation matrix for the given parameters
 - log : dict
 log dictionary returned only if `log` is `True`
 else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
 the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
 - log : dict
 log dictionary returned only if `log` is `True`
@@ -172,12 +173,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
 Parameters
 ----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
 Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
 One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -198,7 +199,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Returns ------- - ot_distance : (n_hists,) ndarray + ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` log : dict log dictionary returned only if `log` is `True` @@ -239,9 +240,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>` """ - b = np.asarray(b, dtype=np.float64) + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] + if method.lower() == 'sinkhorn': return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, @@ -291,12 +293,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b` If many, compute all the OT distances (a, b_i) - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -315,12 +317,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -354,17 +356,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -377,17 +377,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, 1)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, 1), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(M / (-reg)) fi = reg_m / (reg_m + reg) @@ -397,14 +394,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, uprev = u vprev = v - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (a / Kv) ** fi - Ktu = K.T.dot(u) + Ktu = 
nx.dot(K.T, u) v = (b / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -412,8 +409,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, v = vprev break - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err_u = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) + err_v = nx.max(nx.abs(v - vprev)) / max( + nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1. + ) err = 0.5 * (err_u + err_v) if log: log['err'].append(err) @@ -426,11 +427,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, break if log: - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: @@ -475,12 +476,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -501,12 +502,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -538,17 +539,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -561,56 +560,52 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, 
type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. - alpha = np.zeros(dim_a) - beta = np.zeros(dim_b) + alpha = nx.zeros(dim_a, type_as=M) + beta = nx.zeros(dim_b, type_as=M) while (err > stopThr and cpt < numItermax): uprev = u vprev = v - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) if n_hists: f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((a / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) v = ((b / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True if n_hists: - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) else: - alpha = alpha + reg * np.log(np.max(u)) - beta = beta + reg * np.log(np.max(v)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u)) + beta = beta + reg * nx.log(nx.max(v)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -620,8 +615,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 if (cpt % 10 == 0 and not absorbing) or cpt == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), - 1.) + err = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) if log: log['err'].append(err) if verbose: @@ -636,25 +632,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 "Try a larger entropy `reg` or a lower mass `reg_m`." 
+ "Or a larger absorption threshold `tau`.") if n_hists: - logu = alpha[:, None] / reg + np.log(u) - logv = beta[:, None] / reg + np.log(v) + logu = alpha[:, None] / reg + nx.log(u) + logv = beta[:, None] / reg + nx.log(v) else: - logu = alpha / reg + np.log(u) - logv = beta / reg + np.log(v) + logu = alpha / reg + nx.log(u) + logv = beta / reg + nx.log(v) if log: log['logu'] = logu log['logv'] = logv if n_hists: # return only loss - res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + - logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) - res = np.exp(res) + res = nx.logsumexp( + nx.log(M + 1e-100)[:, :, None] + + logu[:, None, :] + + logv[None, :, :] + - M[:, :, None] / reg, + axis=(0, 1) + ) + res = nx.exp(res) if log: return res, log else: return res else: # return OT matrix - ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg) if log: return ot_matrix, log else: @@ -683,9 +684,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 @@ -693,7 +694,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Marginal relaxation term > 0 tau : float Stabilization threshold for log domain absorption. - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -708,7 +709,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -726,9 +727,12 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) @@ -737,47 +741,43 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, fi = reg_m / (reg_m + reg) - u = np.ones((dim, n_hists)) / dim - v = np.ones((dim, n_hists)) / dim + u = nx.ones((dim, n_hists), type_as=A) / dim + v = nx.ones((dim, n_hists), type_as=A) / dim # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. 
- alpha = np.zeros(dim) - beta = np.zeros(dim) - q = np.ones(dim) / dim + alpha = nx.zeros(dim, type_as=A) + beta = nx.zeros(dim, type_as=A) + q = nx.ones(dim, type_as=A) / dim for i in range(numItermax): - qprev = q.copy() - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + qprev = nx.copy(q) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((A / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) q = (Ktu ** (1 - fi)) * f_beta - q = q.dot(weights) ** (1 / (1 - fi)) + q = nx.dot(q, weights) ** (1 / (1 - fi)) Q = q[:, None] v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -786,8 +786,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if (i % 10 == 0 and not absorbing) or i == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(q - qprev).max() / max(abs(q).max(), - abs(qprev).max(), 1.) + err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1. + ) if log: log['err'].append(err) if verbose: @@ -804,8 +805,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, "Or a larger absorption threshold `tau`.") if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -833,15 +834,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. 
numItermax : int, optional @@ -856,7 +857,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -874,40 +875,43 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) if log: log = {'err': []} - K = np.exp(- M / reg) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) - v = np.ones((dim, n_hists)) - u = np.ones((dim, 1)) - q = np.ones(dim) + v = nx.ones((dim, n_hists), type_as=A) + u = nx.ones((dim, 1), type_as=A) + q = nx.ones(dim, type_as=A) err = 1. for i in range(numItermax): - uprev = u.copy() - vprev = v.copy() - qprev = q.copy() + uprev = nx.copy(u) + vprev = nx.copy(v) + qprev = nx.copy(q) - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (A / Kv) ** fi - Ktu = K.T.dot(u) - q = ((Ktu ** (1 - fi)).dot(weights)) + Ktu = nx.dot(K.T, u) + q = nx.dot(Ktu ** (1 - fi), weights) q = q ** (1 / (1 - fi)) Q = q[:, None] v = (Q / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -916,8 +920,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, q = qprev break # compute change in barycenter - err = abs(q - qprev).max() - err /= max(abs(q).max(), abs(qprev).max(), 1.) + err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0 + ) if log: log['err'].append(err) # if barycenter did not change + at least 10 iterations - stop @@ -932,8 +937,8 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -961,15 +966,15 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -984,7 +989,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -1025,3 +1030,225 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, log=log, **kwargs) else: raise ValueError("Unknown method '%s'." 
% method) + + +def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, + stopThr=1e-15, verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + unbalanced distributions + - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence + + The algorithm used for solving the problem is a maximization- + minimization algorithm as proposed in :ref:`[41] <references-regpath>` + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg_m: float + Marginal relaxation term > 0 + div: string, optional + Divergence to quantify the difference between the marginals. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + Returns + ------- + gamma : (dim_a, dim_b) array-like + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'kl'), 2) + array([[0.3 , 0. ], + [0. , 0.07]]) + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'l2'), 2) + array([[0.25, 0. ], + [0. , 0. ]]) + + + .. _references-regpath: + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. 
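Editor's note: `reg_m` interpolates between transporting nothing and enforcing the marginals exactly. A small sketch of this behaviour, reusing the toy data from the doctest above (illustrative only, not part of the commit):

    # Editor's sketch: as reg_m grows, the KL marginal penalty dominates
    # and the transported mass approaches the total mass of a and b.
    import numpy as np
    import ot

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[1., 36.], [9., 4.]])

    for reg_m in [0.1, 1.0, 100.0]:
        G = ot.unbalanced.mm_unbalanced(a, b, M, reg_m, div='kl')
        print(reg_m, G.sum())  # transported mass increases with reg_m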
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT
+ """
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = nx.ones(dim_a, type_as=M) / dim_a
+ if len(b) == 0:
+ b = nx.ones(dim_b, type_as=M) / dim_b
+
+ if G0 is None:
+ G = a[:, None] * b[None, :]
+ else:
+ G = G0
+
+ if log:
+ log = {'err': [], 'G': []}
+
+ if div == 'kl':
+ K = nx.exp(M / - reg_m / 2)
+ elif div == 'l2':
+ K = nx.maximum(a[:, None] + b[None, :] - M / reg_m / 2,
+ nx.zeros((dim_a, dim_b), type_as=M))
+ else:
+ warnings.warn("The div parameter should be either equal to 'kl' or \
+ 'l2': it has been set to 'kl'.")
+ div = 'kl'
+ K = nx.exp(M / - reg_m / 2)
+
+ for i in range(numItermax):
+ Gprev = G
+
+ if div == 'kl':
+ u = nx.sqrt(a / (nx.sum(G, 1) + 1e-16))
+ v = nx.sqrt(b / (nx.sum(G, 0) + 1e-16))
+ G = G * K * u[:, None] * v[None, :]
+ elif div == 'l2':
+ Gd = nx.sum(G, 0, keepdims=True) + nx.sum(G, 1, keepdims=True) + 1e-16
+ G = G * K / Gd
+
+ err = nx.sqrt(nx.sum((G - Gprev) ** 2))
+ if log:
+ log['err'].append(err)
+ log['G'].append(G)
+ if verbose:
+ print('{:5d}|{:8e}|'.format(i, err))
+ if err < stopThr:
+ break
+
+ if log:
+ log['cost'] = nx.sum(G * M)
+ return G, log
+ else:
+ return G
+
+
+def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the transport
+ cost. The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+ The algorithm used for solving the problem is a maximization-
+ minimization algorithm as proposed in :ref:`[41] <references-regpath>`
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg_m: float
+ Marginal relaxation term > 0
+ div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ ot_distance : array-like
+ the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}`
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2)
+ 0.25
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2)
+ 0.57
+
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
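Editor's note: for readers skimming the diff, the core KL update inside `mm_unbalanced` above is compact enough to restate on its own. A direct NumPy transcription of one iteration (editor's sketch mirroring the loop body; not an addition to the library):

    # One KL multiplicative update of the MM solver, as in mm_unbalanced.
    import numpy as np

    def mm_kl_update(G, M, a, b, reg_m):
        K = np.exp(-M / (2 * reg_m))              # fixed kernel, as in the solver
        u = np.sqrt(a / (G.sum(axis=1) + 1e-16))  # row-marginal correction
        v = np.sqrt(b / (G.sum(axis=0) + 1e-16))  # column-marginal correction
        return G * K * u[:, None] * v[None, :]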
+
+ See Also
+ --------
+ ot.lp.emd2 : Unregularized OT loss
+ ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+ """
+ _, log_mm = mm_unbalanced(a, b, M, reg_m, div=div, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ verbose=verbose, log=True)
+
+ if log:
+ return log_mm['cost'], log_mm
+ else:
+ return log_mm['cost']
diff --git a/ot/utils.py b/ot/utils.py
index e6c93c8..a23ce7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
 import sys
 import warnings
 from inspect import signature
-from .backend import get_backend
+from .backend import get_backend, Backend
 __time_tic_toc = time.time()
@@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
 def laplacian(x):
 r"""Compute Laplacian matrix"""
- L = np.diag(np.sum(x, axis=0)) - x
+ nx = get_backend(x)
+ L = nx.diag(nx.sum(x, axis=0)) - x
 return L
@@ -116,7 +117,7 @@ def proj_simplex(v, z=1):
 return w
-def unif(n):
+def unif(n, type_as=None):
 r"""
 Return a uniform histogram of length `n` (simplex).
@@ -124,13 +125,19 @@ def unif(n):
 ----------
 n : int
 number of bins in the histogram
+ type_as : array_like
+ array of the same type as the expected output (numpy/pytorch/jax)
 Returns
 -------
- h : np.array (`n`,)
+ h : array_like (`n`,)
 histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
 """
- return np.ones((n,)) / n
+ if type_as is None:
+ return np.ones((n,)) / n
+ else:
+ nx = get_backend(type_as)
+ return nx.ones((n,), type_as=type_as) / n
 def clean_zeros(a, b, M):
@@ -290,7 +297,8 @@ def cost_normalization(C, norm=None):
 def dots(*args):
 r""" dots function for multiple matrix multiply """
- return reduce(np.dot, args)
+ nx = get_backend(*args)
+ return reduce(nx.dot, args)
 def label_normalization(y, start=0):
@@ -308,8 +316,9 @@ def label_normalization(y, start=0):
 y : array-like, shape (`n1`, )
 The input vector of labels normalized according to given start value.
 """
+ nx = get_backend(y)
- diff = np.min(np.unique(y)) - start
+ diff = nx.min(nx.unique(y)) - start
 if diff != 0:
 y -= diff
 return y
@@ -476,6 +485,19 @@ class BaseEstimator(object):
 arguments (no ``*args`` or ``**kwargs``).
 """
+ nx: Backend = None
+
+ def _get_backend(self, *arrays):
+ nx = get_backend(
+ *[input_ for input_ in arrays if input_ is not None]
+ )
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError(
+ """JAX or TF arrays have been received but domain
+ adaptation does not support those backends.""")
+ self.nx = nx
+ return nx
+
 @classmethod
 def _get_param_names(cls):
 r"""Get parameter names for the estimator"""
diff --git a/ot/weak.py b/ot/weak.py
new file mode 100644
index 0000000..f7d5b23
--- /dev/null
+++ b/ot/weak.py
@@ -0,0 +1,124 @@
+"""
+Weak optimal transport solvers
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .optim import cg
+import numpy as np
+
+__all__ = ['weak_optimal_transport']
+
+
+def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
+ r"""Solves the weak optimal transport problem between two empirical distributions
+
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \|X_a - \mathrm{diag}(1/\mathbf{a}) \gamma X_b\|_F^2
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
+
+ where :
+
+ - :math:`X_a` and :math:`X_b` are the sample matrices.
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ Uses the conditional gradient algorithm to solve the problem proposed
+ in :ref:`[39] <references-weak>`.
+
+ Parameters
+ ----------
+ Xa : (ns,d) array-like, float
+ Source samples
+ Xb : (nt,d) array-like, float
+ Target samples
+ a : (ns,) array-like, float
+ Source histogram (uniform weights if None)
+ b : (nt,) array-like, float
+ Target histogram (uniform weights if None)
+ G0 : (ns,nt) array-like, float, optional
+ Initialization of the transport plan
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
+
+
+ .. _references-weak:
+ References
+ ----------
+ .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
+ Kantorovich duality for general transport costs and applications.
+ Journal of Functional Analysis, 273(11), 3327-3405.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+ """
+
+ nx = get_backend(Xa, Xb)
+
+ Xa2 = nx.to_numpy(Xa)
+ Xb2 = nx.to_numpy(Xb)
+
+ if a is None:
+ a2 = np.ones((Xa.shape[0])) / Xa.shape[0]
+ else:
+ a2 = nx.to_numpy(a)
+ if b is None:
+ b2 = np.ones((Xb.shape[0])) / Xb.shape[0]
+ else:
+ b2 = nx.to_numpy(b)
+
+ # init uniform
+ if G0 is None:
+ T0 = a2[:, None] * b2[None, :]
+ else:
+ T0 = nx.to_numpy(G0)
+
+ # weak OT loss
+ def f(T):
+ return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1))
+
+ # weak OT gradient
+ def df(T):
+ return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T)
+
+ # solve with conditional gradient and return solution
+ if log:
+ res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs)
+ log['u'] = nx.from_numpy(log['u'], type_as=Xa)
+ log['v'] = nx.from_numpy(log['v'], type_as=Xb)
+ return nx.from_numpy(res, type_as=Xa), log
+ else:
+ return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa)
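Editor's note: to close, a minimal usage sketch for the new weak OT solver above (illustrative toy data; the marginal checks follow from the constraints of the problem, which the conditional gradient iterates satisfy by construction):

    # Editor's sketch: weak OT between two toy point clouds, uniform weights.
    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xa = rng.randn(50, 2)         # source samples
    Xb = rng.randn(40, 2) + 1.0   # target samples

    T = ot.weak.weak_optimal_transport(Xa, Xb)

    # the solution is a valid coupling: marginals match the sample weights
    assert np.allclose(T.sum(axis=1), np.ones(50) / 50)
    assert np.allclose(T.sum(axis=0), np.ones(40) / 40)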