summaryrefslogtreecommitdiff
path: root/examples/plot_convolutional_barycenter.py
diff options
context:
space:
mode:
authorNicolas Courty <ncourty@irisa.fr>2018-09-07 11:58:42 +0200
committerNicolas Courty <ncourty@irisa.fr>2018-09-07 11:58:42 +0200
commitd99abf078537acf6cf49480b9790a9c450889031 (patch)
treec9e1138752af5ea4b33d9d46766033386098dd28 /examples/plot_convolutional_barycenter.py
parent5180023fc49d15ad83faccc5674d5966fe9a0385 (diff)
Wasserstein convolutional barycenter
Diffstat (limited to 'examples/plot_convolutional_barycenter.py')
-rw-r--r--examples/plot_convolutional_barycenter.py92
1 files changed, 92 insertions, 0 deletions
diff --git a/examples/plot_convolutional_barycenter.py b/examples/plot_convolutional_barycenter.py
new file mode 100644
index 0000000..d231da9
--- /dev/null
+++ b/examples/plot_convolutional_barycenter.py
@@ -0,0 +1,92 @@
+
+#%%
+# -*- coding: utf-8 -*-
+"""
+============================================
+Convolutional Wasserstein Barycenter example
+============================================
+
+This example is designed to illustrate how the Convolutional Wasserstein Barycenter
+function of POT works.
+"""
+
+# Author: Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import pylab as pl
+import ot
+
+##############################################################################
+# Data preparation
+# ----------------
+#
+# The four distributions are constructed from 4 simple images
+
+
+f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
+f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
+f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
+f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
+
+A = []
+f1=f1/np.sum(f1)
+f2=f2/np.sum(f2)
+f3=f3/np.sum(f3)
+f4=f4/np.sum(f4)
+A.append(f1)
+A.append(f2)
+A.append(f3)
+A.append(f4)
+A=np.array(A)
+
+nb_images = 5
+
+# those are the four corners coordinates that will be interpolated by bilinear
+# interpolation
+v1=np.array((1,0,0,0))
+v2=np.array((0,1,0,0))
+v3=np.array((0,0,1,0))
+v4=np.array((0,0,0,1))
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+pl.figure(figsize=(10,10))
+pl.title('Convolutional Wasserstein Barycenters in POT')
+cm='Blues'
+# regularization parameter
+reg=0.004
+for i in range(nb_images):
+ for j in range(nb_images):
+ pl.subplot(nb_images,nb_images,i*nb_images+j+1)
+ tx=float(i)/(nb_images-1)
+ ty=float(j)/(nb_images-1)
+
+ # weights are constructed by bilinear interpolation
+ tmp1=(1-tx)*v1+tx*v2
+ tmp2=(1-tx)*v3+tx*v4
+ weights=(1-ty)*tmp1+ty*tmp2
+
+ if i==0 and j==0:
+ pl.imshow(f1,cmap=cm)
+ pl.axis('off')
+ elif i==0 and j==(nb_images-1):
+ pl.imshow(f3,cmap=cm)
+ pl.axis('off')
+ elif i==(nb_images-1) and j==0:
+ pl.imshow(f2,cmap=cm)
+ pl.axis('off')
+ elif i==(nb_images-1) and j==(nb_images-1):
+ pl.imshow(f4,cmap=cm)
+ pl.axis('off')
+ else:
+ # call to barycenter computation
+ pl.imshow(ot.convolutional_barycenter2d(A,reg,weights),cmap=cm)
+ pl.axis('off')
+pl.show() \ No newline at end of file