diff options
-rw-r--r-- | README.md | 3 | ||||
-rw-r--r-- | RELEASES.md | 8 | ||||
-rw-r--r-- | docs/source/all.rst | 1 | ||||
-rw-r--r-- | examples/others/plot_WeakOT_VS_OT.py | 98 | ||||
-rw-r--r-- | examples/plot_OT_2D_samples.py | 5 | ||||
-rw-r--r-- | ot/__init__.py | 5 | ||||
-rw-r--r-- | ot/gromov.py | 16 | ||||
-rw-r--r-- | ot/lp/__init__.py | 9 | ||||
-rw-r--r-- | ot/lp/cvx.py | 1 | ||||
-rw-r--r-- | ot/utils.py | 12 | ||||
-rw-r--r-- | ot/weak.py | 124 | ||||
-rw-r--r-- | test/test_bregman.py | 13 | ||||
-rw-r--r-- | test/test_ot.py | 2 | ||||
-rw-r--r-- | test/test_utils.py | 18 | ||||
-rw-r--r-- | test/test_weak.py | 54 |
15 files changed, 343 insertions, 26 deletions
@@ -25,6 +25,7 @@ POT provides the following generic OT solvers (links to examples): * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. +* Weak OT solver between empirical distributions [39] * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] @@ -301,3 +302,5 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020 [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. + +[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
\ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 94c853b..4d05582 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,10 +5,12 @@ #### New features -- Better list of related examples in quick start guide with `minigallery` (PR #334) +- Better list of related examples in quick start guide with `minigallery` (PR #334). - Add optional log-domain Sinkhorn implementation in WDA to support smaller values - of the regularization parameter (PR #336) -- Backend implementation for `ot.lp.free_support_barycenter` (PR #340) + of the regularization parameter (PR #336). +- Backend implementation for `ot.lp.free_support_barycenter` (PR #340). +- Add weak OT solver + example (PR #341). + #### Closed issues diff --git a/docs/source/all.rst b/docs/source/all.rst index 7f85a91..76d2ff5 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -28,6 +28,7 @@ API and modules unbalanced partial sliced + weak .. autosummary:: :toctree: ../modules/generated/ diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py new file mode 100644 index 0000000..a29c875 --- /dev/null +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Weak Optimal Transport VS exact Optimal Transport +==================================================== + +Illustration of 2D optimal transport between distributions that are weighted +sum of diracs. The OT matrix is plotted with the samples. 
+ +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +############################################################################## +# Generate data an plot it +# ------------------------ + +#%% parameters and data generation + +n = 50 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +pl.figure(2) +pl.imshow(M, interpolation='nearest') +pl.title('Cost matrix M') + + +############################################################################## +# Compute Weak OT and exact OT solutions +# -------------------------------------- + +#%% EMD + +G0 = ot.emd(a, b, M) + +#%% Weak OT + +Gweak = ot.weak_optimal_transport(xs, xt, a, b) + + +############################################################################## +# Plot weak OT and exact OT solutions +# -------------------------------------- + +pl.figure(3, (8, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(G0, interpolation='nearest') +pl.title('OT matrix') + +pl.subplot(1, 2, 2) +pl.imshow(Gweak, interpolation='nearest') +pl.title('Weak OT matrix') + +pl.figure(4, (8, 5)) + +pl.subplot(1, 2, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('OT matrix with samples') + +pl.subplot(1, 2, 2) 
+ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Weak OT matrix with samples') diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index af1bc12..c3a7cd8 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -42,7 +42,6 @@ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples # loss matrix M = ot.dist(xs, xt) -M /= M.max() ############################################################################## # Plot data @@ -87,7 +86,7 @@ pl.title('OT matrix with samples') #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Gs = ot.sinkhorn(a, b, M, lambd) @@ -112,7 +111,7 @@ pl.show() #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd) diff --git a/ot/__init__.py b/ot/__init__.py index 1ea7403..7253318 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,6 +36,7 @@ from . import unbalanced from . import partial from . import backend from . import regpath +from . 
import weak # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -46,7 +47,7 @@ from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) - +from .weak import weak_optimal_transport # utils functions from .utils import dist, unif, tic, toc, toq @@ -59,5 +60,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', - 'max_sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/gromov.py b/ot/gromov.py index 6544260..b7e7949 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F - :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= Note that when using backends, this loss function is differentiable wrt the
matrices and weights for quadratic loss using the gradients from [38]_.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
@@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 The algorithm used for solving the problem is conditional gradient as
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Note that when using backends, this loss function is differentiable wrt the
matrices and weights for quadratic loss using the gradients from [38]_.
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2ff7c1f..d9b6fa9 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,6 +26,8 @@ from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend + + __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -220,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] <references-emd>`. @@ -358,7 +361,8 @@ def emd2(a, b, M, processes=1, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] <references-emd2>`. @@ -622,3 +626,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X + diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 869d450..fbf3c0e 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -11,7 +11,6 @@ import numpy as np import scipy as sp import scipy.sparse as sps - try: import cvxopt from cvxopt import solvers, matrix, spmatrix diff --git a/ot/utils.py b/ot/utils.py index e6c93c8..725ca00 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -116,7 +116,7 @@ def proj_simplex(v, z=1): return w -def unif(n): +def unif(n, type_as=None): r""" Return a uniform histogram of length `n` (simplex). 
@@ -124,13 +124,19 @@ def unif(n): ---------- n : int number of bins in the histogram + type_as : array_like + array of the same type of the expected output (numpy/pytorch/jax) Returns ------- - h : np.array (`n`,) + h : array_like (`n`,) histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ - return np.ones((n,)) / n + if type_as is None: + return np.ones((n,)) / n + else: + nx = get_backend(type_as) + return nx.ones((n,)) / n def clean_zeros(a, b, M): diff --git a/ot/weak.py b/ot/weak.py new file mode 100644 index 0000000..f7d5b23 --- /dev/null +++ b/ot/weak.py @@ -0,0 +1,124 @@ +""" +Weak optimal transport solvers +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +from .backend import get_backend +from .optim import cg +import numpy as np + +__all__ = ['weak_optimal_transport'] + + +def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs): + r"""Solves the weak optimal transport problem between two empirical distributions + + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \|X_a-diag(1/a)\gamma X_b\|_F^2 + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`X_a` :math:`X_b` are the sample matrices. + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Uses the conditional gradient algorithm to solve the problem proposed + in :ref:`[39] <references-weak>`. 
+ + Parameters + ---------- + Xa : (ns,d) array-like, float + Source samples + Xb : (nt,d) array-like, float + Target samples + a : (ns,) array-like, float + Source histogram (uniform weight if None) + b : (nt,) array-like, float + Target histogram (uniform weight if None) + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status + + + .. _references-weak: + References + ---------- + .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). + Kantorovich duality for general transport costs and applications. + Journal of Functional Analysis, 273(11), 3327-3405. 
+ + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + nx = get_backend(Xa, Xb) + + Xa2 = nx.to_numpy(Xa) + Xb2 = nx.to_numpy(Xb) + + if a is None: + a2 = np.ones((Xa.shape[0])) / Xa.shape[0] + else: + a2 = nx.to_numpy(a) + if b is None: + b2 = np.ones((Xb.shape[0])) / Xb.shape[0] + else: + b2 = nx.to_numpy(b) + + # init uniform + if G0 is None: + T0 = a2[:, None] * b2[None, :] + else: + T0 = nx.to_numpy(G0) + + # weak OT loss + def f(T): + return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1)) + + # weak OT gradient + def df(T): + return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T) + + # solve with conditional gradient and return solution + if log: + res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs) + log['u'] = nx.from_numpy(log['u'], type_as=Xa) + log['v'] = nx.from_numpy(log['v'], type_as=Xb) + return nx.from_numpy(res, type_as=Xa), log + else: + return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6e90aa4..1419f9b 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -60,7 +60,7 @@ def test_convergence_warning(method): ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) -def test_not_impemented_method(): +def test_not_implemented_method(): # test sinkhorn w = 10 n = w ** 2 @@ -635,7 +635,7 @@ def test_wasserstein_bary_2d(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True) bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -667,7 +667,7 @@ def 
test_wasserstein_bary_2d_debiased(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True) bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -940,14 +940,11 @@ def test_screenkhorn(nx): bb = nx.from_numpy(b) M_nx = nx.from_numpy(M, type_as=ab) - # np sinkhorn - G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03)) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) # check marginals - np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) diff --git a/test/test_ot.py b/test/test_ot.py index e8e2d97..3e2d845 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -232,7 +232,7 @@ def test_emd2_multi(): # Gaussian distributions a = gauss(n, m=20, s=5) # m= mean, s= std - ls = np.arange(20, 500, 20) + ls = np.arange(20, 500, 100) nb = len(ls) b = np.zeros((n, nb)) for i in range(nb): diff --git a/test/test_utils.py b/test/test_utils.py index 8b23c22..5ad167b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -62,12 +62,12 @@ def test_tic_toc(): import time ot.tic() - time.sleep(0.5) + time.sleep(0.1) t = ot.toc() t2 = ot.toq() # test timing - np.testing.assert_allclose(0.5, t, rtol=1e-1, atol=1e-1) + np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1) # test toc vs toq 
np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1) @@ -94,10 +94,22 @@ def test_unif(): np.testing.assert_allclose(1, np.sum(u)) -def test_dist(): +def test_unif_backend(nx): n = 100 + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + u = ot.unif(n, type_as=tp) + + np.testing.assert_allclose(1, np.sum(nx.to_numpy(u)), atol=1e-6) + + +def test_dist(): + + n = 10 + rng = np.random.RandomState(0) x = rng.randn(n, 2) diff --git a/test/test_weak.py b/test/test_weak.py new file mode 100644 index 0000000..c4c3278 --- /dev/null +++ b/test/test_weak.py @@ -0,0 +1,54 @@ +"""Tests for main module ot.weak """ + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import ot +import numpy as np + + +def test_weak_ot(): + # test weak ot solver and identity stationary point + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True) + + # check constraints + np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + # chaeck that identity is recovered + G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n) + + # check G is identity + np.testing.assert_allclose(G, np.eye(n) / n) + + # check constraints + np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + +def test_weak_ot_bakends(nx): + # test weak ot solver for different backends + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G = ot.weak_optimal_transport(xs, xt, u, u) + + xs2 = nx.from_numpy(xs) + xt2 = nx.from_numpy(xt) + u2 = nx.from_numpy(u) + + G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2) + + np.testing.assert_allclose(nx.to_numpy(G2), G) |