diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-02-02 11:53:12 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-02 11:53:12 +0100 |
commit | a5e0f0d40d5046a6639924347ef97e2ac80ad0c9 (patch) | |
tree | dcd35e851ec2cc3f52eedbfa58fb6970664135c9 /ot | |
parent | 71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 (diff) |
[MRG] Add weak OT solver (#341)
* add info in release file
* update tests
* pep8
* add weak OT example
* update plot in doc
* correction ewample with empirical sinkhorn
* better thumbnail
* comment from review
* update documenation
Diffstat (limited to 'ot')
-rw-r--r-- | ot/__init__.py | 5 | ||||
-rw-r--r-- | ot/gromov.py | 16 | ||||
-rw-r--r-- | ot/lp/__init__.py | 9 | ||||
-rw-r--r-- | ot/lp/cvx.py | 1 | ||||
-rw-r--r-- | ot/utils.py | 12 | ||||
-rw-r--r-- | ot/weak.py | 124 |
6 files changed, 159 insertions, 8 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index 1ea7403..7253318 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,6 +36,7 @@ from . import unbalanced from . import partial from . import backend from . import regpath +from . import weak # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -46,7 +47,7 @@ from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) - +from .weak import weak_optimal_transport # utils functions from .utils import dist, unif, tic, toc, toq @@ -59,5 +60,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', - 'max_sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/gromov.py b/ot/gromov.py index 6544260..b7e7949 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F - :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
@@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 The algorithm used for solving the problem is conditional gradient as
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2ff7c1f..d9b6fa9 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,6 +26,8 @@ from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend + + __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -220,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] <references-emd>`. @@ -358,7 +361,8 @@ def emd2(a, b, M, processes=1, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] <references-emd2>`. @@ -622,3 +626,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X + diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 869d450..fbf3c0e 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -11,7 +11,6 @@ import numpy as np import scipy as sp import scipy.sparse as sps - try: import cvxopt from cvxopt import solvers, matrix, spmatrix diff --git a/ot/utils.py b/ot/utils.py index e6c93c8..725ca00 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -116,7 +116,7 @@ def proj_simplex(v, z=1): return w -def unif(n): +def unif(n, type_as=None): r""" Return a uniform histogram of length `n` (simplex). @@ -124,13 +124,19 @@ def unif(n): ---------- n : int number of bins in the histogram + type_as : array_like + array of the same type of the expected output (numpy/pytorch/jax) Returns ------- - h : np.array (`n`,) + h : array_like (`n`,) histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ - return np.ones((n,)) / n + if type_as is None: + return np.ones((n,)) / n + else: + nx = get_backend(type_as) + return nx.ones((n,)) / n def clean_zeros(a, b, M): diff --git a/ot/weak.py b/ot/weak.py new file mode 100644 index 0000000..f7d5b23 --- /dev/null +++ b/ot/weak.py @@ -0,0 +1,124 @@ +""" +Weak optimal ransport solvers +""" + +# Author: Remi Flamary <remi.flamary@polytehnique.edu> +# +# License: MIT License + +from .backend import get_backend +from .optim import cg +import numpy as np + +__all__ = ['weak_optimal_transport'] + + +def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs): + r"""Solves the weak optimal transport problem between two empirical distributions + + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \|X_a-diag(1/a)\gammaX_b\|_F^2 + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`X_a` :math:`X_b` are the sample matrices. + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Uses the conditional gradient algorithm to solve the problem proposed + in :ref:`[39] <references-weak>`. + + Parameters + ---------- + Xa : (ns,d) array-like, float + Source samples + Xb : (nt,d) array-like, float + Target samples + a : (ns,) array-like, float + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list)) + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status + + + .. _references-weak: + References + ---------- + .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). + Kantorovich duality for general transport costs and applications. + Journal of Functional Analysis, 273(11), 3327-3405. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + nx = get_backend(Xa, Xb) + + Xa2 = nx.to_numpy(Xa) + Xb2 = nx.to_numpy(Xb) + + if a is None: + a2 = np.ones((Xa.shape[0])) / Xa.shape[0] + else: + a2 = nx.to_numpy(a) + if b is None: + b2 = np.ones((Xb.shape[0])) / Xb.shape[0] + else: + b2 = nx.to_numpy(b) + + # init uniform + if G0 is None: + T0 = a2[:, None] * b2[None, :] + else: + T0 = nx.to_numpy(G0) + + # weak OT loss + def f(T): + return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1)) + + # weak OT gradient + def df(T): + return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T) + + # solve with conditional gradient and return solution + if log: + res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs) + log['u'] = nx.from_numpy(log['u'], type_as=Xa) + log['v'] = nx.from_numpy(log['v'], type_as=Xb) + return nx.from_numpy(res, type_as=Xa), log + else: + return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa) |