From 897026ea1f5c35ba9e881433bc61490e70776b8c Mon Sep 17 00:00:00 2001
From: Huy Tran
Date: Wed, 22 Mar 2023 08:13:53 +0100
Subject: [MRG] CO-Optimal Transport solver (#447)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Allow warmstart in sinkhorn and sinkhorn_log

* Added argument for warmstart of dual vectors in Sinkhorn-based methods in

* Add the number of the PR

* [WIP] CO-Optimal Transport

* Revert "[WIP] CO-Optimal Transport"

  This reverts commit f3d36b2705013409ac69b346585e311bc25fcfb7.

* reformat with PEP8

* Fix W291 trailing whitespace error in pep8 test

* Rearrange position of warmstart argument and edit its description

* Implementation of CO-Optimal Transport

* Optimize code and edit documentation

* fix backend bug in test cases

* fix backend bug

* fix backend bug

* Add examples on COOT

* Modify API and edit example

* Edit API

* minor edit of examples and release

* fix bug in coot

* fix doc examples

* more fix of doc

* restart CI

* reordering ref

* add more tests

* add more tests

* add test verbose

* fix PEP8 bug

* fix PEP8 bug

* fix PEP8 bug

* fix pytest bug

* edit doc for better display

---------

Co-authored-by: Rémi Flamary
Co-authored-by: Alexandre Gramfort
---
 examples/others/plot_COOT.py                       |  97 +++++++++++++
 examples/others/plot_learning_weights_with_COOT.py | 150 +++++++++++++++++++++
 2 files changed, 247 insertions(+)
 create mode 100644 examples/others/plot_COOT.py
 create mode 100644 examples/others/plot_learning_weights_with_COOT.py

diff --git a/examples/others/plot_COOT.py b/examples/others/plot_COOT.py
new file mode 100644
index 0000000..98c1ce1
--- /dev/null
+++ b/examples/others/plot_COOT.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+r"""
+===================================================
+Row and column alignments with CO-Optimal Transport
+===================================================
+
+This example shows how to use CO-Optimal Transport [49]_ in POT.
+CO-Optimal Transport computes the distance between two **arbitrary-size**
+matrices and aligns their rows and columns. In this example, we consider two
+random matrices :math:`X_1` and :math:`X_2` defined by
+:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal{N}(0,1)`
+and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal{N}(0,1)`.
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+    CO-Optimal Transport. Advances in Neural Information Processing Systems, 33.
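+
+A minimal usage sketch of the two solvers demonstrated below (the random
+matrices and their shapes here are illustrative assumptions):
+
+.. code-block:: python
+
+    import numpy as np
+    from ot.coot import co_optimal_transport as coot
+    from ot.coot import co_optimal_transport2 as coot2
+
+    X1 = np.random.randn(20, 16)  # any (n1, d1) matrix
+    X2 = np.random.randn(10, 8)   # any (n2, d2) matrix; sizes may differ
+
+    # couplings aligning rows with rows and columns with columns
+    pi_sample, pi_feature = coot(X1, X2)
+    # scalar CO-Optimal Transport distance
+    distance = coot2(X1, X2)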
+"""

+
+# Author: Remi Flamary
+#         Quang Huy Tran
+# License: MIT License
+
+from matplotlib.patches import ConnectionPatch
+import matplotlib.pylab as pl
+import numpy as np
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+
+# %%
+# Generating two random matrices
+
+n1 = 20
+n2 = 10
+d1 = 16
+d2 = 8
+sigma = 0.2
+
+X1 = (
+    np.cos(np.arange(n1) * np.pi / n1)[:, None]
+    + np.cos(np.arange(d1) * np.pi / d1)[None, :]
+    + sigma * np.random.randn(n1, d1)
+)
+X2 = (
+    np.cos(np.arange(n2) * np.pi / n2)[:, None]
+    + np.cos(np.arange(d2) * np.pi / d2)[None, :]
+    + sigma * np.random.randn(n2, d2)
+)
+
+# %%
+# Visualizing the matrices
+
+pl.figure(1, (8, 5))
+pl.subplot(1, 2, 1)
+pl.imshow(X1)
+pl.title('$X_1$')
+
+pl.subplot(1, 2, 2)
+pl.imshow(X2)
+pl.title('$X_2$')
+
+pl.tight_layout()
+
+# %%
+# Visualizing the alignments of rows and columns, and computing the CO-Optimal Transport distance
+
+pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True)
+coot_distance = coot2(X1, X2)
+print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance))
+
+fig = pl.figure(4, (9, 7))
+pl.clf()
+
+ax1 = pl.subplot(2, 2, 3)
+pl.imshow(X1)
+pl.xlabel('$X_1$')
+
+ax2 = pl.subplot(2, 2, 2)
+ax2.yaxis.tick_right()
+pl.imshow(np.transpose(X2))
+pl.title('Transpose($X_2$)')
+ax2.xaxis.tick_top()
+
+# black lines: each row of X1 connected to its best-matched row of X2
+for i in range(n1):
+    j = np.argmax(pi_sample[i, :])
+    xyA = (d1 - .5, i)
+    xyB = (j, d2 - .5)
+    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+                          coordsB=ax2.transData, color='black')
+    fig.add_artist(con)
+
+# blue lines: each column of X1 connected to its best-matched column of X2
+for i in range(d1):
+    j = np.argmax(pi_feature[i, :])
+    xyA = (i, -.5)
+    xyB = (-.5, j)
+    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+                          coordsB=ax2.transData, color='blue')
+    fig.add_artist(con)
diff --git a/examples/others/plot_learning_weights_with_COOT.py b/examples/others/plot_learning_weights_with_COOT.py
new file mode 100644
index 0000000..cb115c3
--- /dev/null
+++ b/examples/others/plot_learning_weights_with_COOT.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+r"""
+===============================================================
+Learning sample marginal distribution with CO-Optimal Transport
+===============================================================
+
+In this example, we illustrate how to estimate the sample marginal distribution which
+minimizes the CO-Optimal Transport distance [49]_ between two matrices. More precisely,
+given source data :math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y`
+associated with a fixed histogram on features :math:`\mu_y^{(f)}`, we want to solve the
+following problem:
+
+.. math::
+    \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)
+
+where :math:`\Delta` is the probability simplex. This minimization is carried out with a
+simple projected gradient descent in PyTorch, using the automatic backend of POT so that
+the CO-Optimal Transport distance computed by :func:`ot.coot.co_optimal_transport2` is
+differentiable.
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+    CO-Optimal Transport. Advances in Neural Information Processing Systems, 33.
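+
+As a rough sketch of the loop implemented below (the placeholder matrices and the
+values of ``n``, ``lr`` and ``niter`` here are illustrative assumptions):
+
+.. code-block:: python
+
+    import torch
+    import ot
+    from ot.coot import co_optimal_transport2 as coot2
+
+    n, lr, niter = 15, 1e-3, 1000
+    X = torch.randn(20, 16).double()      # placeholder source matrix
+    Y_noisy = torch.randn(n, 8).double()  # placeholder target matrix
+
+    b = torch.tensor(ot.unif(n), requires_grad=True)  # start from uniform weights
+    for _ in range(niter):
+        loss = coot2(X, Y_noisy, wy_samp=b)  # differentiable COOT distance
+        loss.backward()
+        with torch.no_grad():
+            b -= lr * b.grad                 # gradient step
+            b[:] = ot.utils.proj_simplex(b)  # project back onto the simplex
+        b.grad.zero_()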
+"""

+
+# Author: Remi Flamary
+#         Quang Huy Tran
+# License: MIT License
+
+from matplotlib.patches import ConnectionPatch
+import torch
+import numpy as np
+
+import matplotlib.pyplot as pl
+import ot
+
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+
+
+# %%
+# Generate data
+# -------------
+# The source and the clean target matrices are generated by
+# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
+# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
+# The target matrix is then contaminated by adding 5 row outliers.
+# Intuitively, the estimated sample distribution should ignore these outliers,
+# i.e. give them zero weight.

+torch.manual_seed(182)  # seed the torch generator, which is the one used below
+
+n1, d1 = 20, 16
+n2, d2 = 10, 8
+n = 15  # n2 clean rows plus 5 outlier rows
+
+X = (
+    torch.cos(torch.arange(n1) * torch.pi / n1)[:, None]
+    + torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
+)
+
+# Generate clean target data mixed with outliers
+Y_noisy = torch.randn((n, d2)) * 10.0
+Y_noisy[:n2, :] = (
+    torch.cos(torch.arange(n2) * torch.pi / n2)[:, None]
+    + torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
+)
+Y = Y_noisy[:n2, :]
+
+X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()
+
+fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
+axes[0].imshow(X, vmin=-2, vmax=2)
+axes[0].set_title('$X$')
+
+axes[1].imshow(Y, vmin=-2, vmax=2)
+axes[1].set_title('Clean $Y$')
+
+axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
+axes[2].set_title('Noisy $Y$')
+
+pl.tight_layout()
+
+# %%
+# Optimize the COOT distance with respect to the sample marginal distribution
+# ---------------------------------------------------------------------------
+
+losses = []
+lr = 1e-3
+niter = 1000
+
+b = torch.tensor(ot.unif(n), requires_grad=True)
+
+for i in range(niter):
+    loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
+    losses.append(float(loss))
+
+    loss.backward()
+
+    with torch.no_grad():
+        b -= lr * b.grad  # gradient step
+        b[:] = ot.utils.proj_simplex(b)  # projection onto the simplex
+
+    b.grad.zero_()
+
+# Training loss curve; the estimated sample marginal distribution is printed below
+pl.plot(losses[10:])
+pl.title('CO-Optimal Transport distance')
+
+print(f"Marginal distribution = {b.detach().numpy()}")
+
+# %%
+# Visualizing the row and column alignments with the estimated sample marginal distribution
+# -----------------------------------------------------------------------------------------
+#
+# The learned marginal distribution assigns negligible weight to the 5 outlier rows,
+# so they are effectively ignored in the alignment.
+
+X, Y_noisy = X.numpy(), Y_noisy.numpy()
+b = b.detach().numpy()
+
+pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)
+
+fig = pl.figure(4, (9, 7))
+pl.clf()
+
+ax1 = pl.subplot(2, 2, 3)
+pl.imshow(X, vmin=-2, vmax=2)
+pl.xlabel('$X$')
+
+ax2 = pl.subplot(2, 2, 2)
+ax2.yaxis.tick_right()
+pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
+pl.title('Transpose(Noisy $Y$)')
+ax2.xaxis.tick_top()
+
+# black lines: each row of X connected to its best-matched row of noisy Y
+for i in range(n1):
+    j = np.argmax(pi_sample[i, :])
+    xyA = (d1 - .5, i)
+    xyB = (j, d2 - .5)
+    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+                          coordsB=ax2.transData, color='black')
+    fig.add_artist(con)
+
+# blue lines: each column of X connected to its best-matched column of noisy Y
+for i in range(d1):
+    j = np.argmax(pi_feature[i, :])
+    xyA = (i, -.5)
+    xyB = (-.5, j)
+    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+                          coordsB=ax2.transData, color='blue')
+    fig.add_artist(con)