diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-24 10:29:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-24 10:29:37 +0200 |
commit | c9b99df8fffec1dcc6802ef43b6192774817c5fb (patch) | |
tree | 22939513930c1dd3c28fe93d90f2a7a284a0f82f | |
parent | 4367a343aeb0ceccbb99acc0f92797af020bb537 (diff) | |
parent | ccbe274fd9554492bb88ddaf530c2800a8dc3418 (diff) |
Merge pull request #64 from rflamary/convolution
[MRG] Wasserstein convolutional barycenter
This PR closes Issue #51
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | data/duck.png | bin | 0 -> 5112 bytes | |||
-rw-r--r-- | data/heart.png | bin | 0 -> 5225 bytes | |||
-rw-r--r-- | data/redcross.png | bin | 0 -> 1683 bytes | |||
-rw-r--r-- | data/tooth.png | bin | 0 -> 4931 bytes | |||
-rw-r--r-- | examples/plot_convolutional_barycenter.py | 92 | ||||
-rw-r--r-- | ot/bregman.py | 110 | ||||
-rw-r--r-- | test/test_bregman.py | 24 | ||||
-rw-r--r-- | test/test_stochastic.py | 8 |
9 files changed, 232 insertions, 4 deletions
@@ -228,3 +228,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) [20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning + +[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. diff --git a/data/duck.png b/data/duck.png Binary files differnew file mode 100644 index 0000000..9181697 --- /dev/null +++ b/data/duck.png diff --git a/data/heart.png b/data/heart.png Binary files differnew file mode 100644 index 0000000..44a6385 --- /dev/null +++ b/data/heart.png diff --git a/data/redcross.png b/data/redcross.png Binary files differnew file mode 100644 index 0000000..8d0a6fa --- /dev/null +++ b/data/redcross.png diff --git a/data/tooth.png b/data/tooth.png Binary files differnew file mode 100644 index 0000000..cd92c9d --- /dev/null +++ b/data/tooth.png diff --git a/examples/plot_convolutional_barycenter.py b/examples/plot_convolutional_barycenter.py new file mode 100644 index 0000000..e74db04 --- /dev/null +++ b/examples/plot_convolutional_barycenter.py @@ -0,0 +1,92 @@ + +#%% +# -*- coding: utf-8 -*- +""" +============================================ +Convolutional Wasserstein Barycenter example +============================================ + +This example is designed to illustrate how the Convolutional Wasserstein Barycenter +function of POT works. +""" + +# Author: Nicolas Courty <ncourty@irisa.fr> +# +# License: MIT License + + +import numpy as np +import pylab as pl +import ot + +############################################################################## +# Data preparation +# ---------------- +# +# The four distributions are constructed from 4 simple images + + +f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2] +f2 = 1 - pl.imread('../data/duck.png')[:, :, 2] +f3 = 1 - pl.imread('../data/heart.png')[:, :, 2] +f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2] + +A = [] +f1 = f1 / np.sum(f1) +f2 = f2 / np.sum(f2) +f3 = f3 / np.sum(f3) +f4 = f4 / np.sum(f4) +A.append(f1) +A.append(f2) +A.append(f3) +A.append(f4) +A = np.array(A) + +nb_images = 5 + +# those are the four corners coordinates that will be interpolated by bilinear +# interpolation +v1 = np.array((1, 0, 0, 0)) +v2 = np.array((0, 1, 0, 0)) +v3 = np.array((0, 0, 1, 0)) +v4 = np.array((0, 0, 0, 1)) + + +############################################################################## +# Barycenter computation and visualization +# ---------------------------------------- +# + +pl.figure(figsize=(10, 10)) +pl.title('Convolutional Wasserstein Barycenters in POT') +cm = 'Blues' +# regularization parameter +reg = 0.004 +for i in range(nb_images): + for j in range(nb_images): + pl.subplot(nb_images, nb_images, i * nb_images + j + 1) + tx = float(i) / (nb_images - 1) + ty = float(j) / (nb_images - 1) + + # weights are constructed by bilinear interpolation + tmp1 = (1 - tx) * v1 + tx * v2 + tmp2 = (1 - tx) * v3 + tx * v4 + weights = (1 - ty) * tmp1 + ty * tmp2 + + if i == 0 and j == 0: + pl.imshow(f1, cmap=cm) + pl.axis('off') + elif i == 0 and j == (nb_images - 1): + pl.imshow(f3, cmap=cm) + pl.axis('off') + elif i == (nb_images - 1) and j == 0: + pl.imshow(f2, cmap=cm) + pl.axis('off') + elif i == (nb_images - 1) and j == (nb_images - 1): + pl.imshow(f4, cmap=cm) + pl.axis('off') + else: + # call to barycenter computation + pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm) + pl.axis('off') +pl.show() diff --git a/ot/bregman.py b/ot/bregman.py index c755f51..35e51f8 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -919,6 +919,116 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) +def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False): + """Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` + - reg is the regularization strength scalar value + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_ + + Parameters + ---------- + A : np.ndarray (n,w,h) + n distributions (2D images) of size w x h + reg : float + Regularization term >0 + weights : np.ndarray (n,) + Weights of each image on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (w,h) ndarray + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). + Convolutional wasserstein distances: Efficient optimal transportation on geometric domains + ACM Transactions on Graphics (TOG), 34(4), 66 + + + """ + + if weights is None: + weights = np.ones(A.shape[0]) / A.shape[0] + else: + assert(len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + b = np.zeros_like(A[0, :, :]) + U = np.ones_like(A) + KV = np.ones_like(A) + + cpt = 0 + err = 1 + + # build the convolution operator + t = np.linspace(0, 1, A.shape[1]) + [Y, X] = np.meshgrid(t, t) + xi1 = np.exp(-(X - Y)**2 / reg) + + def K(x): + return np.dot(np.dot(xi1, x), xi1) + + while (err > stopThr and cpt < numItermax): + + bold = b + cpt = cpt + 1 + + b = np.zeros_like(A[0, :, :]) + for r in range(A.shape[0]): + KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :]))) + b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :])) + b = np.exp(b) + for r in range(A.shape[0]): + U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :]) + + if cpt % 10 == 1: + err = np.sum(np.abs(bold - b)) + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + if log: + log['niter'] = cpt + log['U'] = U + return b, log + else: + return b + + def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): """ diff --git a/test/test_bregman.py b/test/test_bregman.py index c8e9179..01ec655 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -105,6 +105,30 @@ def test_bary(): ot.bregman.barycenter(A, M, reg, log=True, verbose=True) +def test_wassersteinbary(): + + size = 100 # size of a square image + a1 = np.random.randn(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) + a2 = np.random.randn(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) + # creating matrix A containing all distributions + A = np.zeros((2, 100, 100)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + # wasserstein + reg = 1e-3 + bary_wass = ot.bregman.convolutional_barycenter2d(A, reg) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + def test_unmix(): n_bins = 50 # nb bins diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 0128317..f0f3fc8 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -32,7 +32,7 @@ def test_stochastic_sag(): # test sag n = 15 reg = 1 - numItermax = 300000 + numItermax = 30000 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -62,7 +62,7 @@ def test_stochastic_asgd(): # test asgd n = 15 reg = 1 - numItermax = 300000 + numItermax = 100000 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -92,7 +92,7 @@ def test_sag_asgd_sinkhorn(): # test all algorithms n = 15 reg = 1 - nb_iter = 300000 + nb_iter = 100000 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn(): # test all dual algorithms n = 10 reg = 1 - nb_iter = 150000 + nb_iter = 15000 batch_size = 10 rng = np.random.RandomState(0) |