diff options
author | Hicham Janati <hicham.janati100@gmail.com> | 2021-11-03 08:41:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-03 08:41:35 +0100 |
commit | e1b67c641da3b3e497db6811af2c200022b10302 (patch) | |
tree | 44d42e1ae50d653bb07dd6ef9c1de14f71b21642 /examples/barycenters/plot_convolutional_barycenter.py | |
parent | 61340d526702616ff000d9e1cf71f52dd199a103 (diff) |
[WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291)
* add debiased sinkhorn barycenter + make loops pythonic
* add debiased arg in tests
* add 1d and 2d examples of debiased barycenters
* fix doctest
* fix flake8
* pep8 + make func private + add convergence warnings
* remove rel paths + add rng + pylab to pyplot
* fix stopping criterion debiased
* pass alex
* change params with new API
* add logdomain barycenters + separate debiased API
* test new API
* fix jax read-only ?
* raise error for jax
* test catch jax error
* fix pytest catch error
* fix relative path
* fix flake8
* add warn arg everywhere
* fix ref number
* catch warnings in tests
* add contrib to readme + change ref number
* fix convolution example + gallery thumbnails
* increase coverage
* fix flake
Co-authored-by: Hicham Janati <hicham.janati@inria.fr>
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'examples/barycenters/plot_convolutional_barycenter.py')
-rw-r--r-- | examples/barycenters/plot_convolutional_barycenter.py | 53 |
1 files changed, 25 insertions, 28 deletions
diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index cbcd4a1..3721f31 100644 --- a/examples/barycenters/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -6,17 +6,18 @@ Convolutional Wasserstein Barycenter example ============================================ -This example is designed to illustrate how the Convolutional Wasserstein Barycenter -function of POT works. +This example is designed to illustrate how the Convolutional Wasserstein +Barycenter function of POT works. """ # Author: Nicolas Courty <ncourty@irisa.fr> # # License: MIT License - +import os +from pathlib import Path import numpy as np -import pylab as pl +import matplotlib.pyplot as plt import ot ############################################################################## @@ -25,22 +26,19 @@ import ot # # The four distributions are constructed from 4 simple images +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') -f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2] -f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2] -f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2] -f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2] +f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] +f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] +f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] -A = [] f1 = f1 / np.sum(f1) f2 = f2 / np.sum(f2) f3 = f3 / np.sum(f3) f4 = f4 / np.sum(f4) -A.append(f1) -A.append(f2) -A.append(f3) -A.append(f4) -A = np.array(A) +A = np.array([f1, f2, f3, f4]) nb_images = 5 @@ -57,14 +55,13 @@ v4 = np.array((0, 0, 0, 1)) # ---------------------------------------- # -pl.figure(figsize=(10, 10)) -pl.title('Convolutional Wasserstein Barycenters in POT') +fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7)) +plt.suptitle('Convolutional Wasserstein Barycenters in POT') cm = 'Blues' # regularization parameter reg = 0.004 for i in range(nb_images): for j in range(nb_images): - pl.subplot(nb_images, nb_images, i * nb_images + j + 1) tx = float(i) / (nb_images - 1) ty = float(j) / (nb_images - 1) @@ -74,19 +71,19 @@ for i in range(nb_images): weights = (1 - ty) * tmp1 + ty * tmp2 if i == 0 and j == 0: - pl.imshow(f1, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f1, cmap=cm) elif i == 0 and j == (nb_images - 1): - pl.imshow(f3, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f3, cmap=cm) elif i == (nb_images - 1) and j == 0: - pl.imshow(f2, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f2, cmap=cm) elif i == (nb_images - 1) and j == (nb_images - 1): - pl.imshow(f4, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f4, cmap=cm) else: # call to barycenter computation - pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm) - pl.axis('off') -pl.show() + axes[i, j].imshow( + ot.bregman.convolutional_barycenter2d(A, reg, weights), + cmap=cm + ) + axes[i, j].axis('off') +plt.tight_layout() +plt.show() |