summaryrefslogtreecommitdiff
path: root/examples/barycenters/plot_generalized_free_support_barycenter.py
blob: e685ec71104d7318f0def4a3f902e64cdaf94b88 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# -*- coding: utf-8 -*-
"""
=======================================
Generalized Wasserstein Barycenter Demo
=======================================

This example illustrates the computation of Generalized Wasserstein Barycenter
as proposed in [42].


[42] Delon, J., Gozlan, N., and Saint-Dizier, A..
Generalized Wasserstein barycenters between probability measures living on different subspaces.
arXiv preprint arXiv:2105.09755, 2021.

"""

# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
import ot
import matplotlib.animation as animation

########################
# Generate and plot data
# ----------------------

# Input measures
sub_sample_factor = 8
I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
I3 = pl.imread('../../data/heart.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]

sz = I1.shape[0]
UU, VV = np.meshgrid(np.arange(sz), np.arange(sz))

# Input measure locations in their respective 2D spaces
X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]]

# Input measure weights
a_list = [ot.unif(x.shape[0]) for x in X_list]

# Projections 3D -> 2D
P1 = np.array([[1, 0, 0], [0, 1, 0]])
P2 = np.array([[0, 1, 0], [0, 0, 1]])
P3 = np.array([[1, 0, 0], [0, 0, 1]])
P_list = [P1, P2, P3]

# Barycenter weights
weights = np.array([1 / 3, 1 / 3, 1 / 3])

# Number of barycenter points to compute
n_samples_bary = 150

# Send the input measures into 3D space for visualisation
X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)]

# Plot the input data
fig = plt.figure(figsize=(3, 3))
axis = fig.add_subplot(1, 1, 1, projection="3d")
for Xi in X_visu:
    axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
axis.view_init(azim=45)
axis.set_xticks([])
axis.set_yticks([])
axis.set_zticks([])
plt.show()

#################################
# Barycenter computation and plot
# -------------------------------

Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary)
fig = plt.figure(figsize=(3, 3))

axis = fig.add_subplot(1, 1, 1, projection="3d")
for Xi in X_visu:
    axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
axis.view_init(azim=45)
axis.set_xticks([])
axis.set_yticks([])
axis.set_zticks([])
plt.show()


#############################
# Plotting projection matches
# ---------------------------

fig = plt.figure(figsize=(9, 3))

ax = fig.add_subplot(1, 3, 1, projection='3d')
for Xi in X_visu:
    ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
ax.view_init(elev=0, azim=0)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax = fig.add_subplot(1, 3, 2, projection='3d')
for Xi in X_visu:
    ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
ax.view_init(elev=0, azim=90)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax = fig.add_subplot(1, 3, 3, projection='3d')
for Xi in X_visu:
    ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
ax.view_init(elev=90, azim=0)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

plt.tight_layout()
plt.show()

##############################################
# Rotation animation
# --------------------------------------------

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(1, 1, 1, projection="3d")


def _init():
    for Xi in X_visu:
        ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
    ax.view_init(elev=0, azim=0)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    return fig,


def _update_plot(i):
    if i < 45:
        ax.view_init(elev=0, azim=4 * i)
    else:
        ax.view_init(elev=i - 45, azim=4 * i)
    return fig,


ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=136, interval=50, blit=True, repeat_delay=2000)