summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md4
-rw-r--r--data/duck.pngbin0 -> 5112 bytes
-rw-r--r--data/heart.pngbin0 -> 5225 bytes
-rw-r--r--data/redcross.pngbin0 -> 1683 bytes
-rw-r--r--data/tooth.pngbin0 -> 4931 bytes
-rw-r--r--examples/plot_convolutional_barycenter.py92
-rw-r--r--ot/bregman.py110
-rw-r--r--test/test_bregman.py24
-rw-r--r--test/test_stochastic.py8
9 files changed, 233 insertions, 5 deletions
diff --git a/README.md b/README.md
index 4d824ce..1c8114a 100644
--- a/README.md
+++ b/README.md
@@ -230,4 +230,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
-[21] J. Altschuler, J.Weed, P. Rigollet, (2017) Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31 \ No newline at end of file
+[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
+
+[22] J. Altschuler, J. Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
diff --git a/data/duck.png b/data/duck.png
new file mode 100644
index 0000000..9181697
--- /dev/null
+++ b/data/duck.png
Binary files differ
diff --git a/data/heart.png b/data/heart.png
new file mode 100644
index 0000000..44a6385
--- /dev/null
+++ b/data/heart.png
Binary files differ
diff --git a/data/redcross.png b/data/redcross.png
new file mode 100644
index 0000000..8d0a6fa
--- /dev/null
+++ b/data/redcross.png
Binary files differ
diff --git a/data/tooth.png b/data/tooth.png
new file mode 100644
index 0000000..cd92c9d
--- /dev/null
+++ b/data/tooth.png
Binary files differ
diff --git a/examples/plot_convolutional_barycenter.py b/examples/plot_convolutional_barycenter.py
new file mode 100644
index 0000000..e74db04
--- /dev/null
+++ b/examples/plot_convolutional_barycenter.py
@@ -0,0 +1,92 @@
+
#%%
# -*- coding: utf-8 -*-
"""
============================================
Convolutional Wasserstein Barycenter example
============================================

This example is designed to illustrate how the Convolutional Wasserstein Barycenter
function of POT works.
"""

# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License


import numpy as np
import pylab as pl
import ot

##############################################################################
# Data preparation
# ----------------
#
# Four small images are turned into four 2D distributions


def load_distribution(path):
    """Read an image, invert its third channel and scale it to unit mass."""
    inverted = 1 - pl.imread(path)[:, :, 2]
    return inverted / np.sum(inverted)


f1 = load_distribution('../data/redcross.png')
f2 = load_distribution('../data/duck.png')
f3 = load_distribution('../data/heart.png')
f4 = load_distribution('../data/tooth.png')

# stack the distributions along a new leading axis
A = np.array([f1, f2, f3, f4])

nb_images = 5

# one-hot weight vectors, one per input image; the grid of barycenters is
# obtained by bilinear interpolation between them
v1, v2, v3, v4 = np.eye(4, dtype=int)


##############################################################################
# Barycenter computation and visualization
# ----------------------------------------
#

pl.figure(figsize=(10, 10))
pl.title('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004

last = nb_images - 1
# the four corners of the grid show the input images themselves
corner_images = {
    (0, 0): f1,
    (0, last): f3,
    (last, 0): f2,
    (last, last): f4,
}

for cell in range(nb_images * nb_images):
    # row-major position of this subplot in the grid
    i, j = divmod(cell, nb_images)
    pl.subplot(nb_images, nb_images, cell + 1)

    if (i, j) in corner_images:
        picture = corner_images[(i, j)]
    else:
        tx = float(i) / last
        ty = float(j) / last

        # bilinear interpolation of the four one-hot weight vectors
        tmp1 = (1 - tx) * v1 + tx * v2
        tmp2 = (1 - tx) * v3 + tx * v4
        weights = (1 - ty) * tmp1 + ty * tmp2

        # call to barycenter computation
        picture = ot.bregman.convolutional_barycenter2d(A, reg, weights)

    pl.imshow(picture, cmap=cm)
    pl.axis('off')
pl.show()
diff --git a/ot/bregman.py b/ot/bregman.py
index 1f5150a..418de57 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1070,6 +1070,116 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
    r"""Compute the entropic regularized wasserstein barycenter of distributions A
    where A is a collection of 2D images.

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions of matrix :math:`\mathbf{A}`
    - reg is the regularization strength scalar value

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_

    The cost is never materialised: a Gaussian kernel
    :math:`\exp(-(x-y)^2/reg)` is built on a regular grid of [0, 1] for each
    image axis and applied separably (one matrix product per axis), so
    rectangular images (w != h) are supported.

    Parameters
    ----------
    A : np.ndarray (n,w,h)
        n distributions (2D images) of size w x h
    reg : float
        Regularization term >0
    weights : np.ndarray (n,)
        Weights of each image on the simplex (barycentric coordinates).
        Uniform weights are used when None.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    stabThr : float, optional
        Stabilization threshold to avoid numerical precision issue
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    a : (w,h) ndarray
        2D Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters. It holds the
        recorded errors ('err'), the number of iterations ('niter') and the
        final scaling array ('U').


    References
    ----------

    .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
        Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
        ACM Transactions on Graphics (TOG), 34(4), 66


    """

    if weights is None:
        weights = np.ones(A.shape[0]) / A.shape[0]
    else:
        assert len(weights) == A.shape[0]

    if log:
        log = {'err': []}

    b = np.zeros_like(A[0, :, :])
    U = np.ones_like(A)
    KV = np.ones_like(A)

    cpt = 0
    err = 1

    def gaussian_kernel(size):
        # 1D Gaussian kernel between all pairs of points of a regular grid
        t = np.linspace(0, 1, size)
        [Y, X] = np.meshgrid(t, t)
        return np.exp(-(X - Y)**2 / reg)

    # build the convolution operator: one kernel per image axis, so that the
    # left product acts on the rows (w) and the right product on the columns (h)
    xi1 = gaussian_kernel(A.shape[1])
    xi2 = gaussian_kernel(A.shape[2])

    def K(x):
        return np.dot(np.dot(xi1, x), xi2)

    while (err > stopThr and cpt < numItermax):

        bold = b
        cpt = cpt + 1

        # the new barycenter is the weighted geometric mean of U * KV,
        # accumulated in the log domain
        b = np.zeros_like(A[0, :, :])
        for r in range(A.shape[0]):
            KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
            b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
        b = np.exp(b)
        for r in range(A.shape[0]):
            U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])

        # the error is only refreshed every 10 iterations (cpt = 1, 11, 21, ...)
        if cpt % 10 == 1:
            err = np.sum(np.abs(bold - b))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # cpt always ends in 1 here, so the header period has to be
                # tested against 1 as well, otherwise it would never print
                if cpt % 200 == 1:
                    print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))

    if log:
        log['niter'] = cpt
        log['U'] = U
        return b, log
    else:
        return b
+
+
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
stopThr=1e-3, verbose=False, log=False):
"""
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 52bbbd2..58afd7a 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -108,6 +108,30 @@ def test_bary():
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
def test_wassersteinbary():
    """Check that the convolutional barycenter of two random images has unit mass."""

    size = 100  # size of a square image
    # fixed seed so that the test is reproducible
    rng = np.random.RandomState(0)

    def random_distribution():
        # shift by the minimum so every entry is non-negative, then normalise
        img = rng.randn(size, size)
        img -= img.min()
        return img / np.sum(img)

    # creating matrix A containing all distributions
    A = np.zeros((2, size, size))
    A[0, :, :] = random_distribution()
    A[1, :, :] = random_distribution()

    # wasserstein
    reg = 1e-3
    bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)

    np.testing.assert_allclose(1, np.sum(bary_wass))

    # help in checking if log and verbose do not bug the function
    ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+
def test_unmix():
n_bins = 50 # nb bins
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 0128317..f0f3fc8 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -32,7 +32,7 @@ def test_stochastic_sag():
# test sag
n = 15
reg = 1
- numItermax = 300000
+ numItermax = 30000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -62,7 +62,7 @@ def test_stochastic_asgd():
# test asgd
n = 15
reg = 1
- numItermax = 300000
+ numItermax = 100000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -92,7 +92,7 @@ def test_sag_asgd_sinkhorn():
# test all algorithms
n = 15
reg = 1
- nb_iter = 300000
+ nb_iter = 100000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn():
# test all dual algorithms
n = 10
reg = 1
- nb_iter = 150000
+ nb_iter = 15000
batch_size = 10
rng = np.random.RandomState(0)