summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2022-03-24 14:13:25 +0100
committerGitHub <noreply@github.com>2022-03-24 14:13:25 +0100
commit82452e0f5f6dae05c7a1cc384e7a1fb62ae7e0d5 (patch)
tree051871e3dc63e6bba1d0ecb1df6229796edd33bb
parent767171593f2a98a26b9a39bf110a45085e3b982e (diff)
[MRG] Add factored coupling (#358)
* add gfactored ot * pep8 and add doc * add exmaple for factotred OT * final number of PR * correct test on backends * remove useless loss * better tests
-rw-r--r--README.md4
-rw-r--r--RELEASES.md1
-rw-r--r--docs/source/all.rst1
-rw-r--r--examples/others/plot_factored_coupling.py86
-rw-r--r--ot/__init__.py5
-rw-r--r--ot/factored.py145
-rw-r--r--ot/plot.py7
-rw-r--r--test/test_factored.py56
8 files changed, 303 insertions, 2 deletions
diff --git a/README.md b/README.md
index c6bfd9c..ec5d221 100644
--- a/README.md
+++ b/README.md
@@ -305,4 +305,6 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020
[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
-[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. \ No newline at end of file
+[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
+
+[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. \ No newline at end of file
diff --git a/RELEASES.md b/RELEASES.md
index 86b401a..c2bd0d1 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,6 +5,7 @@
#### New features
+- Implementation of factored OT with emd and sinkhorn (PR #358).
- A brand new logo for POT (PR #357)
- Better list of related examples in quick start guide with `minigallery` (PR #334).
- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
diff --git a/docs/source/all.rst b/docs/source/all.rst
index 76d2ff5..3f7d029 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -29,6 +29,7 @@ API and modules
partial
sliced
weak
+ factored
.. autosummary::
:toctree: ../modules/generated/
diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py
new file mode 100644
index 0000000..b5b1c9f
--- /dev/null
+++ b/examples/others/plot_factored_coupling.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================
+Optimal transport with factored couplings
+==========================================
+
+Illustration of the factored coupling OT between 2D empirical distributions
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+# %%
+# Generate data an plot it
+# ------------------------
+
+# parameters and data generation
+
+np.random.seed(42)
+
+n = 100 # nb samples
+
+xs = np.random.rand(n, 2) - .5
+
+xs = xs + np.sign(xs)
+
+xt = np.random.rand(n, 2) - .5
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+# %%
+# Compute Factore OT and exact OT solutions
+# --------------------------------------
+
+#%% EMD
+M = ot.dist(xs, xt)
+G0 = ot.emd(a, b, M)
+
+#%% factored OT OT
+
+Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4)
+
+
+# %%
+# Plot factored OT and exact OT solutions
+# --------------------------------------
+
+pl.figure(2, (14, 4))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Exact OT with samples')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5)
+ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples')
+pl.title('Factored OT with template samples')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Factored OT low rank OT plan')
diff --git a/ot/__init__.py b/ot/__init__.py
index bda7a35..c5e1967 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -33,6 +33,7 @@ from . import partial
from . import backend
from . import regpath
from . import weak
+from . import factored
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -44,6 +45,9 @@ from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
+from .factored import factored_optimal_transport
+
+
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -57,4 +61,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
+ 'factored_optimal_transport',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
diff --git a/ot/factored.py b/ot/factored.py
new file mode 100644
index 0000000..abc2445
--- /dev/null
+++ b/ot/factored.py
@@ -0,0 +1,145 @@
+"""
+Factored OT solvers (low rank, cost or OT plan)
+"""
+
+# Author: Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .utils import dist
+from .lp import emd
+from .bregman import sinkhorn
+
+__all__ = ['factored_optimal_transport']
+
+
+def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
+ r"""Solves factored OT problem and return OT plans and intermediate distribution
+
+ This function solve the following OT problem [40]_
+
+ .. math::
+ \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)
+
+ where :
+
+ - :math:`\mu_a` and :math:`\mu_b` are empirical distributions.
+ - :math:`\mu` is an empirical distribution with r samples
+
+ And returns the two OT plans between
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ Uses the conditional gradient algorithm to solve the problem proposed in
+ :ref:`[39] <references-weak>`.
+
+ Parameters
+ ----------
+ Xa : (ns,d) array-like, float
+ Source samples
+ Xb : (nt,d) array-like, float
+ Target samples
+ a : (ns,) array-like, float
+ Source histogram (uniform weight if empty list)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list))
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ Ga: array-like, shape (ns, r)
+ Optimal transportation matrix between source and the intermediate
+ distribution
+ Gb: array-like, shape (r, nt)
+ Optimal transportation matrix between the intermediate and target
+ distribution
+ X: array-like, shape (r, d)
+ Support of the intermediate distribution
+ log: dict, optional
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
+
+
+ .. _references-factored:
+ References
+ ----------
+ .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
+ G., & Weed, J. (2019, April). Statistical optimal transport via factored
+ couplings. In The 22nd International Conference on Artificial
+ Intelligence and Statistics (pp. 2454-2465). PMLR.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General
+ regularized OT
+ """
+
+ nx = get_backend(Xa, Xb)
+
+ n_a = Xa.shape[0]
+ n_b = Xb.shape[0]
+ d = Xa.shape[1]
+
+ if a is None:
+ a = nx.ones((n_a), type_as=Xa) / n_a
+ if b is None:
+ b = nx.ones((n_b), type_as=Xb) / n_b
+
+ if X0 is None:
+ X = nx.randn(r, d, type_as=Xa)
+ else:
+ X = X0
+
+ w = nx.ones(r, type_as=Xa) / r
+
+ def solve_ot(X1, X2, w1, w2):
+ M = dist(X1, X2)
+ if reg > 0:
+ G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
+ log['cost'] = nx.sum(G * M)
+ return G, log
+ else:
+ return emd(w1, w2, M, log=True, **kwargs)
+
+ norm_delta = []
+
+ # solve the barycenter
+ for i in range(numItermax):
+
+ old_X = X
+
+ # solve OT with template
+ Ga, loga = solve_ot(Xa, X, a, w)
+ Gb, logb = solve_ot(X, Xb, w, b)
+
+ X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r
+
+ delta = nx.norm(X - old_X)
+ if delta < stopThr:
+ break
+ if log:
+ norm_delta.append(delta)
+
+ if log:
+ log_dic = {'delta_iter': norm_delta,
+ 'ua': loga['u'],
+ 'va': loga['v'],
+ 'ub': logb['u'],
+ 'vb': logb['v'],
+ 'costa': loga['cost'],
+ 'costb': logb['cost'],
+ }
+ return Ga, Gb, X, log_dic
+
+ return Ga, Gb, X
diff --git a/ot/plot.py b/ot/plot.py
index 2208c90..8ade2eb 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
if ('color' not in kwargs) and ('c' not in kwargs):
kwargs['color'] = 'k'
mx = G.max()
+ if 'alpha' in kwargs:
+ scale = kwargs['alpha']
+ del kwargs['alpha']
+ else:
+ scale = 1
for i in range(xs.shape[0]):
for j in range(xt.shape[0]):
if G[i, j] / mx > thr:
pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
- alpha=G[i, j] / mx, **kwargs)
+ alpha=G[i, j] / mx * scale, **kwargs)
diff --git a/test/test_factored.py b/test/test_factored.py
new file mode 100644
index 0000000..fd2fd01
--- /dev/null
+++ b/test/test_factored.py
@@ -0,0 +1,56 @@
+"""Tests for main module ot.weak """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import ot
+import numpy as np
+
+
+def test_factored_ot():
+ # test weak ot solver and identity stationary point
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+ Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=10, log=True)
+
+ # check constraints
+ np.testing.assert_allclose(u, Ga.sum(1))
+ np.testing.assert_allclose(u, Gb.sum(0))
+
+
+def test_factored_ot_backends(nx):
+ # test weak ot solver for different backends
+ n = 50
+ rng = np.random.RandomState(0)
+
+ xs = rng.randn(n, 2)
+ xt = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ xs2 = nx.from_numpy(xs)
+ xt2 = nx.from_numpy(xt)
+ u2 = nx.from_numpy(u)
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, u2, u2, r=10)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))
+
+ Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, reg=1, r=10, X0=X2)
+
+ # check constraints
+ np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1))
+ np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0))