diff options
-rw-r--r-- | README.md | 4 | ||||
-rw-r--r-- | RELEASES.md | 1 | ||||
-rw-r--r-- | docs/source/all.rst | 1 | ||||
-rw-r--r-- | examples/others/plot_factored_coupling.py | 86 | ||||
-rw-r--r-- | ot/__init__.py | 5 | ||||
-rw-r--r-- | ot/factored.py | 145 | ||||
-rw-r--r-- | ot/plot.py | 7 | ||||
-rw-r--r-- | test/test_factored.py | 56 |
8 files changed, 303 insertions, 2 deletions
@@ -305,4 +305,6 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020 [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. -[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
\ No newline at end of file +[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. + +[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.
\ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 86b401a..c2bd0d1 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features +- Implementation of factored OT with emd and sinkhorn (PR #358). - A brand new logo for POT (PR #357) - Better list of related examples in quick start guide with `minigallery` (PR #334). - Add optional log-domain Sinkhorn implementation in WDA to support smaller values diff --git a/docs/source/all.rst b/docs/source/all.rst index 76d2ff5..3f7d029 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -29,6 +29,7 @@ API and modules partial sliced weak + factored .. autosummary:: :toctree: ../modules/generated/ diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py new file mode 100644 index 0000000..b5b1c9f --- /dev/null +++ b/examples/others/plot_factored_coupling.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +========================================== +Optimal transport with factored couplings +========================================== + +Illustration of the factored coupling OT between 2D empirical distributions + +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +# %% +# Generate data and plot it +# ------------------------- + +# parameters and data generation + +np.random.seed(42) + +n = 100 # nb samples + +xs = np.random.rand(n, 2) - .5 + +xs = xs + np.sign(xs) + +xt = np.random.rand(n, 2) - .5 + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + + +# %% +# Compute Factored OT and exact OT solutions +# ------------------------------------------ + +#%% EMD +M 
= ot.dist(xs, xt) +G0 = ot.emd(a, b, M) + +#%% factored OT + +Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4) + + +# %% +# Plot factored OT and exact OT solutions +# --------------------------------------- + +pl.figure(2, (14, 4)) + +pl.subplot(1, 3, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Exact OT with samples') + +pl.subplot(1, 3, 2) +ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5) +ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples') +pl.title('Factored OT with template samples') + +pl.subplot(1, 3, 3) +ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Factored OT low rank OT plan') diff --git a/ot/__init__.py b/ot/__init__.py index bda7a35..c5e1967 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -33,6 +33,7 @@ from . import partial from . import backend from . import regpath from . import weak +from . 
import factored # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -44,6 +45,9 @@ from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport +from .factored import factored_optimal_transport + + # utils functions from .utils import dist, unif, tic, toc, toq @@ -57,4 +61,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', + 'factored_optimal_transport', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/factored.py b/ot/factored.py new file mode 100644 index 0000000..abc2445 --- /dev/null +++ b/ot/factored.py @@ -0,0 +1,145 @@ +""" +Factored OT solvers (low rank, cost or OT plan) +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +from .backend import get_backend +from .utils import dist +from .lp import emd +from .bregman import sinkhorn + +__all__ = ['factored_optimal_transport'] + + +def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs): + r"""Solves factored OT problem and returns OT plans and intermediate distribution + + This function solves the following OT problem [40]_ + + .. math:: + \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) + + where : + + - :math:`\mu_a` and :math:`\mu_b` are empirical distributions. + - :math:`\mu` is an empirical distribution with r samples + + And returns the two OT plans between the source, the intermediate and the target distributions + + .. 
note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Alternates between solving the two OT problems and a barycentric update of + the intermediate support to solve the problem of :ref:`[40] <references-factored>`. + + Parameters + ---------- + Xa : (ns,d) array-like, float + Source samples + Xb : (nt,d) array-like, float + Target samples + a : (ns,) array-like, float + Source histogram (uniform weight if None) + b : (nt,) array-like, float + Target histogram (uniform weight if None) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on the norm of the update of the intermediate support (>0) + verbose : bool, optional + Print information along iterations (currently not used) + log : bool, optional + record log if True + + + Returns + ------- + Ga: array-like, shape (ns, r) + Optimal transportation matrix between source and the intermediate + distribution + Gb: array-like, shape (r, nt) + Optimal transportation matrix between the intermediate and target + distribution + X: array-like, shape (r, d) + Support of the intermediate distribution + log: dict, optional + If input log is true, a dictionary containing the costs, dual + variables and norms of the support updates along iterations + + + .. _references-factored: + References + ---------- + .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, + G., & Weed, J. (2019, April). Statistical optimal transport via factored + couplings. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2454-2465). PMLR. 
+ + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + nx = get_backend(Xa, Xb) + + n_a = Xa.shape[0] + n_b = Xb.shape[0] + d = Xa.shape[1] + + if a is None: + a = nx.ones((n_a), type_as=Xa) / n_a + if b is None: + b = nx.ones((n_b), type_as=Xb) / n_b + + if X0 is None: + X = nx.randn(r, d, type_as=Xa) + else: + X = X0 + + w = nx.ones(r, type_as=Xa) / r + + def solve_ot(X1, X2, w1, w2): + M = dist(X1, X2) + if reg > 0: + G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs) + log['cost'] = nx.sum(G * M) + return G, log + else: + return emd(w1, w2, M, log=True, **kwargs) + + norm_delta = [] + + # solve the barycenter + for i in range(numItermax): + + old_X = X + + # solve OT with template + Ga, loga = solve_ot(Xa, X, a, w) + Gb, logb = solve_ot(X, Xb, w, b) + + X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r + + delta = nx.norm(X - old_X) + if delta < stopThr: + break + if log: + norm_delta.append(delta) + + if log: + log_dic = {'delta_iter': norm_delta, + 'ua': loga['u'], + 'va': loga['v'], + 'ub': logb['u'], + 'vb': logb['v'], + 'costa': loga['cost'], + 'costb': logb['cost'], + } + return Ga, Gb, X, log_dic + + return Ga, Gb, X @@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): if ('color' not in kwargs) and ('c' not in kwargs): kwargs['color'] = 'k' mx = G.max() + if 'alpha' in kwargs: + scale = kwargs['alpha'] + del kwargs['alpha'] + else: + scale = 1 for i in range(xs.shape[0]): for j in range(xt.shape[0]): if G[i, j] / mx > thr: pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], - alpha=G[i, j] / mx, **kwargs) + alpha=G[i, j] / mx * scale, **kwargs) diff --git a/test/test_factored.py b/test/test_factored.py new file mode 100644 index 0000000..fd2fd01 --- /dev/null +++ b/test/test_factored.py @@ -0,0 +1,56 @@ +"""Tests for main module ot.factored """ + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import ot +import numpy as np + + +def 
test_factored_ot(): + # test factored ot solver with exact (emd) and entropic (sinkhorn) OT + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=10, log=True) + + # check constraints + np.testing.assert_allclose(u, Ga.sum(1)) + np.testing.assert_allclose(u, Gb.sum(0)) + + Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=10, log=True) + + # check constraints + np.testing.assert_allclose(u, Ga.sum(1)) + np.testing.assert_allclose(u, Gb.sum(0)) + + +def test_factored_ot_backends(nx): + # test factored ot solver for different backends + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + xs2 = nx.from_numpy(xs) + xt2 = nx.from_numpy(xt) + u2 = nx.from_numpy(u) + + Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, u2, u2, r=10) + + # check constraints + np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1)) + np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0)) + + Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, reg=1, r=10, X0=X2) + + # check constraints + np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1)) + np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0)) |