summaryrefslogtreecommitdiff
path: root/examples/unbalanced-partial/plot_unbalanced_OT.py
blob: 03487e7e2f0382e91c54a962b5891180837729d7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- coding: utf-8 -*-
"""
==============================================================
2D examples of exact and entropic unbalanced optimal transport
==============================================================
This example is designed to show how to compute unbalanced and
partial OT in POT.

UOT aims at solving the following optimization problem:

    .. math::
        W = \min_{\gamma} <\gamma, \mathbf{M}>_F +
        \mathrm{reg}\cdot\Omega(\gamma) +
          \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

where :math:`\mathrm{div}` is a divergence.
When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}`
should be the Kullback-Leibler divergence.
When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}`
can be either the Kullback-Leibler or the quadratic divergence.
Using the :math:`\ell_1` norm as the divergence gives the so-called
partial OT.
"""

# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot

##############################################################################
# Generate data
# -------------

# %% parameters and data generation

# Each point cloud is made of `n` Gaussian samples followed by `n_noise`
# outliers drawn uniformly in a unit square placed away from the Gaussian mean.
n = 40  # nb samples
n_noise = 10  # nb outliers appended to each cloud

# source Gaussian
mu_s = np.array([-1, -1])
cov_s = np.array([[1, 0], [0, 1]])

# target Gaussian (negatively correlated coordinates)
mu_t = np.array([4, 4])
cov_t = np.array([[1, -0.8], [-0.8, 1]])

# fixed seed so that the figure is reproducible; the order of the random
# draws below (source, target, source noise, target noise) matters
np.random.seed(0)
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)

# rand() is uniform on [0, 1), so the outliers lie in [-4, -3)^2 and [6, 7)^2
noise_s = np.random.rand(n_noise, 2) - 4
noise_t = np.random.rand(n_noise, 2) + 6
xs = np.concatenate([xs, noise_s])
xt = np.concatenate([xt, noise_t])

# from here on `n` counts the outliers as well
n += n_noise

# uniform distribution on samples
a = np.ones(n) / n
b = np.ones(n) / n

# loss matrix, rescaled so that its largest entry equals 1
M = ot.dist(xs, xt)
M = M / M.max()


##############################################################################
# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT, partial OT
# ---------------------------------------------------------------------------

# Hyper-parameters of the four solvers compared below.
reg = 0.005  # entropic regularization, only passed to the Sinkhorn solver
reg_m_kl = 0.05  # marginal relaxation used with the KL divergence
reg_m_l2 = 5  # marginal relaxation used with the quadratic divergence
mass = 0.7  # passed as `m` to the partial OT solver

# entropic UOT with KL marginal penalties
entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(
    a, b, M, reg=reg, reg_m=reg_m_kl
)

# unregularized UOT with KL, then quadratic, marginal penalties
kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m_kl, div='kl')
l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m_l2, div='l2')

# partial OT
partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass)

##############################################################################
# Plot the results
# ----------------

# One column per transport plan.
# Top row: the two point clouds with a segment for every positive coupling.
# Bottom row: the same plan displayed as an image.

pl.figure(2)
transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot]

# The LaTeX fragments are raw strings so that their backslashes are not parsed
# as (invalid) escape sequences; the line break is kept in a regular string so
# that "\n" is still a real newline.
reg_m_label = r"$\mathrm{reg_m}$="
title = [
    "partial OT \n m=" + str(mass),
    r"$\ell_2$-UOT" + " \n " + reg_m_label + str(reg_m_l2),
    "kl-UOT \n " + reg_m_label + str(reg_m_kl),
    "entropic kl-UOT \n " + reg_m_label + str(reg_m_kl),
]

n_plans = len(transp)
for p, (P, plan_title) in enumerate(zip(transp, title)):
    # ---- top row: mappings ----
    pl.subplot(2, n_plans, p + 1)
    if P.sum() > 0:
        # rescale so that the largest coupling is 1; skipped when the plan
        # carries no mass, which would otherwise divide by zero
        P = P / P.max()
    # visit only the strictly positive couplings, in row-major order
    for i, j in np.argwhere(P > 0):
        pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
                alpha=P[i, j] * 0.3)
    # faint markers show every sample ...
    pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
    pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
    # ... and opaque markers are sized by the mass each sample sends (row
    # sums) or receives (column sums) in the normalised plan
    pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2)
    pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2)
    pl.title(plan_title)
    pl.yticks(())
    pl.xticks(())
    if p == 0:
        # label each row once, on the leftmost panel
        pl.ylabel("mappings")

    # ---- bottom row: transport plans ----
    pl.subplot(2, n_plans, p + 1 + n_plans)
    pl.imshow(P, cmap='jet')
    pl.yticks(())
    pl.xticks(())
    if p == 0:
        pl.ylabel("transport plans")
pl.show()