diff --git a/examples/others/plot_learning_weights_with_COOT.py b/examples/others/plot_learning_weights_with_COOT.py
new file mode 100644
index 0000000..cb115c3
--- /dev/null
+++ b/examples/others/plot_learning_weights_with_COOT.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+r"""
+===============================================================
+Learning sample marginal distribution with CO-Optimal Transport
+===============================================================
+
+In this example, we illustrate how to estimate the sample marginal distribution that minimizes
+the CO-Optimal Transport distance [49]_ between two matrices. More precisely, given a source dataset
+:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
+histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem:
+
+.. math::
+ \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)
+
+where :math:`\Delta` is the probability simplex. This minimization is done with a
+simple projected gradient descent in PyTorch. We use the automatic backend of POT so that
+the CO-Optimal Transport distance returned by :func:`ot.coot.co_optimal_transport2` is
+differentiable with respect to the weights and can be backpropagated through.
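+
+As a reminder, for the squared Euclidean cost, the CO-Optimal Transport distance between
+two datasets :math:`X` and :math:`Y` reads (see [49]_)
+
+.. math::
+    \text{COOT}(X, Y) = \min_{\pi^s, \pi^f} \sum_{i, j, k, l} \left| X_{i, k} - Y_{j, l} \right|^2 \pi^s_{i, j} \pi^f_{k, l}
+
+where the sample coupling :math:`\pi^s` has marginals :math:`\mu_x^{(s)}` and :math:`\mu_y^{(s)}`,
+and the feature coupling :math:`\pi^f` has marginals :math:`\mu_x^{(f)}` and :math:`\mu_y^{(f)}`.
+Each projected gradient iteration below thus performs the update
+:math:`\mu_y^{(s)} \leftarrow \operatorname{proj}_{\Delta}\left( \mu_y^{(s)} - \eta \nabla_{\mu_y^{(s)}} \text{COOT} \right)`
+for a step size :math:`\eta`.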
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+ `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
+ Advances in Neural Information Processing Systems, 33.
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# License: MIT License
+
+from matplotlib.patches import ConnectionPatch
+import torch
+import numpy as np
+
+import matplotlib.pyplot as pl
+import ot
+
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+
+
+# %%
+# Generate data
+# -------------
+# The source and clean target matrices are generated by
+# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
+# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
+# The target matrix is then contaminated by adding 5 row outliers.
+# Intuitively, we expect that the estimated sample distribution should ignore these outliers,
+# i.e. their weights should be zero.
+
+# Fix both the NumPy and torch seeds for reproducibility (the noise below is drawn with torch)
+np.random.seed(182)
+torch.manual_seed(182)
+
+n1, d1 = 20, 16
+n2, d2 = 10, 8
+n = 15
+
+X = (
+ torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] +
+ torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
+)
+
+# Generate clean target data mixed with outliers
+Y_noisy = torch.randn((n, d2)) * 10.0
+Y_noisy[:n2, :] = (
+ torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] +
+ torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
+)
+Y = Y_noisy[:n2, :]
+
+X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()
+
+fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
+axes[0].imshow(X, vmin=-2, vmax=2)
+axes[0].set_title('$X$')
+
+axes[1].imshow(Y, vmin=-2, vmax=2)
+axes[1].set_title('Clean $Y$')
+
+axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
+axes[2].set_title('Noisy $Y$')
+
+pl.tight_layout()
+
+# %%
+# Optimize the COOT distance with respect to the sample marginal distribution
+# ---------------------------------------------------------------------------
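+#
+# We learn the sample weights ``b`` of the noisy target by projected gradient descent:
+# at each iteration, we evaluate the differentiable COOT loss with
+# :func:`ot.coot.co_optimal_transport2`, backpropagate it with ``loss.backward()``,
+# take a gradient step on ``b`` and project the result back onto the probability
+# simplex with :func:`ot.utils.proj_simplex`.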
+
+losses = []
+lr = 1e-3
+niter = 1000
+
+b = torch.tensor(ot.unif(n), requires_grad=True)
+
+for i in range(niter):
+
+ loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
+ losses.append(float(loss))
+
+ loss.backward()
+
+ with torch.no_grad():
+ b -= lr * b.grad # gradient step
+ b[:] = ot.utils.proj_simplex(b) # projection on the simplex
+
+ b.grad.zero_()
+
+# Estimated sample marginal distribution and training loss curve
+pl.figure(2)
+pl.plot(losses[10:])
+pl.title('CO-Optimal Transport distance')
+
+print(f"Marginal distribution = {b.detach().numpy()}")
+
+# %%
+# Visualizing the row and column alignments with the estimated sample marginal distribution
+# -----------------------------------------------------------------------------------------
+#
+# As expected, the learned marginal distribution ignores the 5 outliers, assigning them (near-)zero weight.
+
+X, Y_noisy = X.numpy(), Y_noisy.numpy()
+b = b.detach().numpy()
+
+pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)
+
+fig = pl.figure(4, (9, 7))
+pl.clf()
+
+ax1 = pl.subplot(2, 2, 3)
+pl.imshow(X, vmin=-2, vmax=2)
+pl.xlabel('$X$')
+
+ax2 = pl.subplot(2, 2, 2)
+ax2.yaxis.tick_right()
+pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
+pl.title("Transpose(Noisy $Y$)")
+ax2.xaxis.tick_top()
+
+for i in range(n1):
+ j = np.argmax(pi_sample[i, :])
+ xyA = (d1 - .5, i)
+ xyB = (j, d2 - .5)
+ con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+ coordsB=ax2.transData, color="black")
+ fig.add_artist(con)
+
+for i in range(d1):
+ j = np.argmax(pi_feature[i, :])
+ xyA = (i, -.5)
+ xyB = (-.5, j)
+ con = ConnectionPatch(
+ xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
+ fig.add_artist(con)
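+
+# Black connections show the estimated sample (row) alignment, blue connections the
+# feature (column) alignment.
+pl.show()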