1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
|
#%%
# -*- coding: utf-8 -*-
"""
============================================
Convolutional Wasserstein Barycenter example
============================================
This example is designed to illustrate how the Convolutional Wasserstein Barycenter
function of POT works.
"""
# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
import numpy as np
import pylab as pl
import ot
##############################################################################
# Data preparation
# ----------------
#
# The four distributions are constructed from 4 simple images
def _load_distribution(path):
    """Load an image file and turn it into a normalized 2D distribution.

    The third channel of the image is inverted (``1 - value``), so that
    low-intensity pixels receive high mass. The result is divided by its
    total so that it sums to one.
    """
    img = 1 - pl.imread(path)[:, :, 2]
    return img / np.sum(img)


f1, f2, f3, f4 = [
    _load_distribution(path)
    for path in (
        '../data/redcross.png',
        '../data/duck.png',
        '../data/heart.png',
        '../data/tooth.png',
    )
]
# Stack the four distributions along a new first axis; this is the array
# handed to the barycenter solver (assumes all images share the same size).
A = np.array([f1, f2, f3, f4])

# side length of the (nb_images x nb_images) grid of displayed barycenters
nb_images = 5
# those are the four corners coordinates that will be interpolated by bilinear
# interpolation: one one-hot weight vector per input distribution
v1, v2, v3, v4 = np.eye(4, dtype=int)
##############################################################################
# Barycenter computation and visualization
# ----------------------------------------
#
pl.figure(figsize=(10, 10))
# suptitle, not title: pl.title would attach the title to a full-figure axes
# that the subplot grid below overlaps (depending on the Matplotlib version
# that axes is either removed, losing the title, or left behind the grid).
pl.suptitle('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004
# Grid cells whose interpolated weights are exactly one corner vector show the
# corresponding input distribution directly instead of a computed barycenter.
corner_images = {
    (0, 0): f1,                          # weights == v1
    (nb_images - 1, 0): f2,              # weights == v2
    (0, nb_images - 1): f3,              # weights == v3
    (nb_images - 1, nb_images - 1): f4,  # weights == v4
}
for i in range(nb_images):
    for j in range(nb_images):
        pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
        if (i, j) in corner_images:
            img = corner_images[(i, j)]
        else:
            tx = float(i) / (nb_images - 1)
            ty = float(j) / (nb_images - 1)
            # weights are constructed by bilinear interpolation
            tmp1 = (1 - tx) * v1 + tx * v2
            tmp2 = (1 - tx) * v3 + tx * v4
            weights = (1 - ty) * tmp1 + ty * tmp2
            # call to barycenter computation
            img = ot.convolutional_barycenter2d(A, reg, weights)
        pl.imshow(img, cmap=cm)
        pl.axis('off')
pl.show()
|