author     Huy Tran <huytran82125@gmail.com>    2023-03-22 08:13:53 +0100
committer  GitHub <noreply@github.com>          2023-03-22 08:13:53 +0100
commit     897026ea1f5c35ba9e881433bc61490e70776b8c (patch)
tree       2cedb9c1a0faa971f6e78c5574d7411ce8e55079
parent     b9ed7b1650475420cc5bbec6c31476cc098790d5 (diff)
[MRG] CO-Optimal Transport solver (#447)
* Allow warmstart in sinkhorn and sinkhorn_log
* Added argument for warmstart of dual vectors in Sinkhorn-based methods in
* Add the number of the PR
* [WIP] CO-Optimal Transport
* Revert "[WIP] CO-Optimal Transport"

  This reverts commit f3d36b2705013409ac69b346585e311bc25fcfb7.
* reformat with PEP8
* Fix W291 trailing whitespace error in pep8 test
* Rearange position of warmstart argument and edit its description
* Implementation of CO-Optimal Transport
* Optimize code and edit documentation
* fix backend bug in test cases
* fix backend bug
* fix backend bug
* Add examples on COOT
* Modify API and edit example
* Edit API
* minor edit of examples and release
* fix bug in coot
* fix doc examples
* more fix of doc
* restart CI
* reordering ref
* add more tests
* add more tests
* add test verbose
* fix PEP8 bug
* fix PEP8 bug
* fix PEP8 bug
* fix pytest bug
* edit doc for better display

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Alexandre Gramfort <agramfort@fb.com>
-rw-r--r--  README.md                                             12
-rw-r--r--  RELEASES.md                                            6
-rw-r--r--  docs/source/all.rst                                    1
-rw-r--r--  examples/others/plot_COOT.py                          97
-rw-r--r--  examples/others/plot_learning_weights_with_COOT.py   150
-rw-r--r--  ot/coot.py                                           434
-rw-r--r--  test/test_coot.py                                    359
7 files changed, 1052 insertions, 7 deletions
diff --git a/README.md b/README.md
index e7241b8..9c5e07e 100644
--- a/README.md
+++ b/README.md
@@ -276,15 +276,15 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
-[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
-(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
-via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
Machine Learning (pp. 4104-4113). PMLR.
[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International
Conference on Machine Learning, PMLR 119:4692-4701, 2020
-[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
@@ -305,4 +305,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787.
-[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
\ No newline at end of file
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
+
+[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
diff --git a/RELEASES.md b/RELEASES.md
index e4c6e15..bc0b189 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -16,8 +16,10 @@
- New API for OT solver using function `ot.solve` (PR #388)
- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449)
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
-- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
-- `ot.dr` now uses the new Pymanopt API and POT is compatible with current Pymanopt (PR #443)
+- Added parameters method in `ot.da.SinkhornTransport` (PR #440)
+- `ot.dr` now uses the new Pymanopt API and POT is compatible with current
+ Pymanopt (PR #443)
+- Added CO-Optimal Transport solver + examples (PR #447)
- Remove the redundant `nx.abs()` at the end of `wasserstein_1d()` (PR #448)
#### Closed issues
diff --git a/docs/source/all.rst b/docs/source/all.rst
index 41d8e06..1b8d13c 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -16,6 +16,7 @@ API and modules
backend
bregman
+ coot
da
datasets
dr
diff --git a/examples/others/plot_COOT.py b/examples/others/plot_COOT.py
new file mode 100644
index 0000000..98c1ce1
--- /dev/null
+++ b/examples/others/plot_COOT.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+r"""
+===================================================
+Row and column alignments with CO-Optimal Transport
+===================================================
+
+This example shows how to use CO-Optimal Transport [49]_ in POT.
+CO-Optimal Transport allows one to compute the distance between two **arbitrary-size**
+matrices, and to align their rows and columns. In this example, we consider two
+random matrices :math:`X_1` and :math:`X_2` defined by
+:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal N(0,1)`
+and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal N(0,1)`.
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+ `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
+ Advances in Neural Information Processing Systems, 33.
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# License: MIT License
+
+from matplotlib.patches import ConnectionPatch
+import matplotlib.pylab as pl
+import numpy as np
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+
+# %%
+# Generating two random matrices
+
+n1 = 20
+n2 = 10
+d1 = 16
+d2 = 8
+sigma = 0.2
+
+X1 = (
+ np.cos(np.arange(n1) * np.pi / n1)[:, None] +
+ np.cos(np.arange(d1) * np.pi / d1)[None, :] +
+ sigma * np.random.randn(n1, d1)
+)
+X2 = (
+ np.cos(np.arange(n2) * np.pi / n2)[:, None] +
+ np.cos(np.arange(d2) * np.pi / d2)[None, :] +
+ sigma * np.random.randn(n2, d2)
+)
+
+# %%
+# Visualizing the matrices
+
+pl.figure(1, (8, 5))
+pl.subplot(1, 2, 1)
+pl.imshow(X1)
+pl.title('$X_1$')
+
+pl.subplot(1, 2, 2)
+pl.imshow(X2)
+pl.title("$X_2$")
+
+pl.tight_layout()
+
+# %%
+# Visualizing the alignments of rows and columns, and calculating the CO-Optimal Transport distance
+
+pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True)
+coot_distance = coot2(X1, X2)
+print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance))
+
+fig = pl.figure(4, (9, 7))
+pl.clf()
+
+ax1 = pl.subplot(2, 2, 3)
+pl.imshow(X1)
+pl.xlabel('$X_1$')
+
+ax2 = pl.subplot(2, 2, 2)
+ax2.yaxis.tick_right()
+pl.imshow(np.transpose(X2))
+pl.title("Transpose($X_2$)")
+ax2.xaxis.tick_top()
+
+for i in range(n1):
+ j = np.argmax(pi_sample[i, :])
+ xyA = (d1 - .5, i)
+ xyB = (j, d2 - .5)
+ con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+ coordsB=ax2.transData, color="black")
+ fig.add_artist(con)
+
+for i in range(d1):
+ j = np.argmax(pi_feature[i, :])
+ xyA = (i, -.5)
+ xyB = (-.5, j)
+ con = ConnectionPatch(
+ xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
+ fig.add_artist(con)
diff --git a/examples/others/plot_learning_weights_with_COOT.py b/examples/others/plot_learning_weights_with_COOT.py
new file mode 100644
index 0000000..cb115c3
--- /dev/null
+++ b/examples/others/plot_learning_weights_with_COOT.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+r"""
+===============================================================
+Learning sample marginal distribution with CO-Optimal Transport
+===============================================================
+
+In this example, we illustrate how to estimate the sample marginal distribution which minimizes
+the CO-Optimal Transport distance [49]_ between two matrices. More precisely, given source data
+:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
+histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem
+
+.. math::
+ \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)
+
+where :math:`\Delta` is the probability simplex. This minimization is done with a
+simple projected gradient descent in PyTorch. We use the automatic backend of POT,
+which makes the CO-Optimal Transport distance computed by :func:`ot.coot.co_optimal_transport2`
+differentiable with respect to its inputs.
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+ `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
+ Advances in Neural Information Processing Systems, 33.
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# License: MIT License
+
+from matplotlib.patches import ConnectionPatch
+import torch
+import numpy as np
+
+import matplotlib.pyplot as pl
+import ot
+
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+
+
+# %%
+# Generate data
+# -------------
+# The source and clean target matrices are generated by
+# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
+# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
+# The target matrix is then contaminated by adding 5 row outliers.
+# Intuitively, we expect that the estimated sample distribution should ignore these outliers,
+# i.e. their weights should be zero.
+
+np.random.seed(182)
+
+n1, d1 = 20, 16
+n2, d2 = 10, 8
+n = 15
+
+X = (
+ torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] +
+ torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
+)
+
+# Generate clean target data mixed with outliers
+Y_noisy = torch.randn((n, d2)) * 10.0
+Y_noisy[:n2, :] = (
+ torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] +
+ torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
+)
+Y = Y_noisy[:n2, :]
+
+X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()
+
+fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
+axes[0].imshow(X, vmin=-2, vmax=2)
+axes[0].set_title('$X$')
+
+axes[1].imshow(Y, vmin=-2, vmax=2)
+axes[1].set_title('Clean $Y$')
+
+axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
+axes[2].set_title('Noisy $Y$')
+
+pl.tight_layout()
+
+# %%
+# Optimize the COOT distance with respect to the sample marginal distribution
+# ---------------------------------------------------------------------------
+
+losses = []
+lr = 1e-3
+niter = 1000
+
+b = torch.tensor(ot.unif(n), requires_grad=True)
+
+for i in range(niter):
+
+ loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
+ losses.append(float(loss))
+
+ loss.backward()
+
+ with torch.no_grad():
+ b -= lr * b.grad # gradient step
+ b[:] = ot.utils.proj_simplex(b) # projection on the simplex
+
+ b.grad.zero_()
+
+# Estimated sample marginal distribution and training loss curve
+pl.plot(losses[10:])
+pl.title('CO-Optimal Transport distance')
+
+print(f"Marginal distribution = {b.detach().numpy()}")
+
+# %%
+# Visualizing the row and column alignments with the estimated sample marginal distribution
+# -----------------------------------------------------------------------------------------
+#
+# Clearly, the learned marginal distribution ignores the 5 outliers: their weights are essentially zero.
+
+X, Y_noisy = X.numpy(), Y_noisy.numpy()
+b = b.detach().numpy()
+
+pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)
+
+fig = pl.figure(4, (9, 7))
+pl.clf()
+
+ax1 = pl.subplot(2, 2, 3)
+pl.imshow(X, vmin=-2, vmax=2)
+pl.xlabel('$X$')
+
+ax2 = pl.subplot(2, 2, 2)
+ax2.yaxis.tick_right()
+pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
+pl.title("Transpose(Noisy $Y$)")
+ax2.xaxis.tick_top()
+
+for i in range(n1):
+ j = np.argmax(pi_sample[i, :])
+ xyA = (d1 - .5, i)
+ xyB = (j, d2 - .5)
+ con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
+ coordsB=ax2.transData, color="black")
+ fig.add_artist(con)
+
+for i in range(d1):
+ j = np.argmax(pi_feature[i, :])
+ xyA = (i, -.5)
+ xyB = (-.5, j)
+ con = ConnectionPatch(
+ xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
+ fig.add_artist(con)
diff --git a/ot/coot.py b/ot/coot.py
new file mode 100644
index 0000000..66dd2c8
--- /dev/null
+++ b/ot/coot.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+"""
+CO-Optimal Transport solver
+"""
+
+# Author: Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+#
+# License: MIT License
+
+import warnings
+from .lp import emd
+from .utils import list_to_array
+from .backend import get_backend
+from .bregman import sinkhorn
+
+
+def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None,
+ epsilon=0, alpha=0, M_samp=None, M_feat=None,
+ warmstart=None, nits_bcd=100, tol_bcd=1e-7, eval_bcd=1,
+ nits_ot=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn",
+ early_stopping_tol=1e-6, log=False, verbose=False):
+ r"""Compute the CO-Optimal Transport between two matrices.
+
+ Return the sample and feature transport plans between
+ :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and
+ :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`.
+
+ The function solves the following CO-Optimal Transport (COOT) problem:
+
+ .. math::
+ \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}}
+ &\quad \sum_{i,j,k,l}
+ (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l}
+ + \alpha_s \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\
+ &+ \alpha_f \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l}
+ + \varepsilon_s \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T)
+ + \varepsilon_f \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T)
+
+ Where :
+
+ - :math:`\mathbf{X}`: Data matrix in the source space
+ - :math:`\mathbf{Y}`: Data matrix in the target space
+ - :math:`\mathbf{M^{(s)}}`: Additional sample matrix
+ - :math:`\mathbf{M^{(f)}}`: Additional feature matrix
+ - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space
+ - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space
+ - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space
+ - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space
+
+ .. note:: This function allows epsilon to be zero.
+ In that case, the :any:`ot.lp.emd` solver of POT will be used.
+
+ Parameters
+ ----------
+ X : (n_sample_x, n_feature_x) array-like, float
+ First input matrix.
+ Y : (n_sample_y, n_feature_y) array-like, float
+ Second input matrix.
+ wx_samp : (n_sample_x, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix X.
+ Uniform distribution by default.
+ wx_feat : (n_feature_x, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix X.
+ Uniform distribution by default.
+ wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix Y.
+ Uniform distribution by default.
+ wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix Y.
+ Uniform distribution by default.
+    epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Regularization parameters for the entropic approximation of the sample and feature couplings.
+        Epsilon is allowed to contain 0; in that case, the EMD solver is used instead of the
+        Sinkhorn solver. If epsilon is a scalar, then the same value is applied to
+        the regularization of both the sample and feature couplings.
+    alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Coefficient of the linear terms with respect to the sample and feature couplings.
+        If alpha is a scalar, then the same value is applied to both linear terms.
+ M_samp : (n_sample_x, n_sample_y), float, optional (default = None)
+ Sample matrix with respect to the linear term on sample coupling.
+ M_feat : (n_feature_x, n_feature_y), float, optional (default = None)
+ Feature matrix with respect to the linear term on feature coupling.
+ warmstart : dictionary, optional (default = None)
+ Contains 4 keys:
+ - "duals_sample" and "duals_feature" whose values are
+ tuples of 2 vectors of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature dual vectors
+ if using Sinkhorn algorithm. Zero vectors by default.
+
+ - "pi_sample" and "pi_feature" whose values are matrices
+ of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature couplings.
+ Uniform distributions by default.
+ nits_bcd : int, optional (default = 100)
+ Number of Block Coordinate Descent (BCD) iterations to solve COOT.
+ tol_bcd : float, optional (default = 1e-7)
+ Tolerance of BCD scheme. If the L1-norm between the current and previous
+ sample couplings is under this threshold, then stop BCD scheme.
+ eval_bcd : int, optional (default = 1)
+ Multiplier of iteration at which the COOT cost is evaluated. For example,
+ if `eval_bcd = 8`, then the cost is calculated at iterations 8, 16, 24, etc...
+    nits_ot : int, optional (default = 500)
+ Number of iterations to solve each of the
+ two optimal transport problems in each BCD iteration.
+ tol_sinkhorn : float, optional (default = 1e-7)
+ Tolerance of Sinkhorn algorithm to stop the Sinkhorn scheme for
+ entropic optimal transport problem (if any) in each BCD iteration.
+ Only triggered when Sinkhorn solver is used.
+ method_sinkhorn : string, optional (default = "sinkhorn")
+ Method used in POT's `ot.sinkhorn` solver.
+        Only "sinkhorn" and "sinkhorn_log" are supported.
+ early_stopping_tol : float, optional (default = 1e-6)
+ Tolerance for the early stopping. If the absolute difference between
+ the last 2 recorded COOT distances is under this tolerance, then stop BCD scheme.
+ log : bool, optional (default = False)
+ If True then the cost and 4 dual vectors, including
+ 2 from sample and 2 from feature couplings, are recorded.
+ verbose : bool, optional (default = False)
+        If True, print the COOT cost at every multiple of `eval_bcd` iterations.
+
+ Returns
+ -------
+ pi_samp : (n_sample_x, n_sample_y) array-like, float
+ Sample coupling matrix.
+ pi_feat : (n_feature_x, n_feature_y) array-like, float
+ Feature coupling matrix.
+ log : dictionary, optional
+ Returned if `log` is True. The keys are:
+ duals_sample : (n_sample_x, n_sample_y) tuple, float
+ Pair of dual vectors when solving OT problem w.r.t the sample coupling.
+ duals_feature : (n_feature_x, n_feature_y) tuple, float
+ Pair of dual vectors when solving OT problem w.r.t the feature coupling.
+ distances : list, float
+ List of COOT distances.
+
+ References
+ ----------
+    .. [49] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
+        Advances in Neural Information Processing Systems, 33 (2020).
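+
+    Examples
+    --------
+    A minimal usage sketch on small random matrices (for illustration only;
+    the couplings returned depend on the random draw):
+
+    >>> import numpy as np
+    >>> from ot.coot import co_optimal_transport
+    >>> X = np.random.randn(10, 4)
+    >>> Y = np.random.randn(8, 3)
+    >>> pi_samp, pi_feat = co_optimal_transport(X, Y)
+    >>> pi_samp.shape, pi_feat.shape
+    ((10, 8), (4, 3))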
+ """
+
+ def compute_kl(p, q):
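+        # relative entropy sum(p * log(p / q)), with the convention 0 * log(0) = 0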
+ kl = nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))
+ return kl
+
+ # Main function
+
+ if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
+ raise ValueError(
+ "Method {} is not supported in CO-Optimal Transport.".format(method_sinkhorn))
+
+ X, Y = list_to_array(X, Y)
+ nx = get_backend(X, Y)
+
+ if isinstance(epsilon, float) or isinstance(epsilon, int):
+ eps_samp, eps_feat = epsilon, epsilon
+ else:
+ if len(epsilon) != 2:
+ raise ValueError("Epsilon must be either a scalar or an indexable object of length 2.")
+ else:
+ eps_samp, eps_feat = epsilon[0], epsilon[1]
+
+ if isinstance(alpha, float) or isinstance(alpha, int):
+ alpha_samp, alpha_feat = alpha, alpha
+ else:
+ if len(alpha) != 2:
+ raise ValueError("Alpha must be either a scalar or an indexable object of length 2.")
+ else:
+ alpha_samp, alpha_feat = alpha[0], alpha[1]
+
+ # constant input variables
+ if M_samp is None or alpha_samp == 0:
+ M_samp, alpha_samp = 0, 0
+ if M_feat is None or alpha_feat == 0:
+ M_feat, alpha_feat = 0, 0
+
+ nx_samp, nx_feat = X.shape
+ ny_samp, ny_feat = Y.shape
+
+ # measures on rows and columns
+ if wx_samp is None:
+ wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp
+ if wx_feat is None:
+ wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat
+ if wy_samp is None:
+ wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp
+ if wy_feat is None:
+ wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat
+
+ wxy_samp = wx_samp[:, None] * wy_samp[None, :]
+ wxy_feat = wx_feat[:, None] * wy_feat[None, :]
+
+ # pre-calculate cost constants
+    XY_sqr = (X ** 2 @ wx_feat)[:, None] + \
+        (Y ** 2 @ wy_feat)[None, :] + alpha_samp * M_samp
+    XY_sqr_T = ((X.T) ** 2 @ wx_samp)[:, None] + \
+        ((Y.T) ** 2 @ wy_samp)[None, :] + alpha_feat * M_feat
+
+ # initialize coupling and dual vectors
+ if warmstart is None:
+ pi_samp, pi_feat = wxy_samp, wxy_feat # shape nx_samp x ny_samp and nx_feat x ny_feat
+ duals_samp = (nx.zeros(nx_samp, type_as=X), nx.zeros(
+ ny_samp, type_as=Y)) # shape nx_samp, ny_samp
+ duals_feat = (nx.zeros(nx_feat, type_as=X), nx.zeros(
+ ny_feat, type_as=Y)) # shape nx_feat, ny_feat
+ else:
+ pi_samp, pi_feat = warmstart["pi_sample"], warmstart["pi_feature"]
+ duals_samp, duals_feat = warmstart["duals_sample"], warmstart["duals_feature"]
+
+ # initialize log
+ list_coot = [float("inf")]
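+    # the initial "inf" entry ensures the early-stopping test cannot trigger
+    # before the first cost evaluation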
+ err = tol_bcd + 1e-3
+
+ for idx in range(nits_bcd):
+ pi_samp_prev = nx.copy(pi_samp)
+
+ # update sample coupling
+ ot_cost = XY_sqr - 2 * X @ pi_feat @ Y.T # size nx_samp x ny_samp
+ if eps_samp > 0:
+ pi_samp, dict_log = sinkhorn(a=wx_samp, b=wy_samp, M=ot_cost, reg=eps_samp, method=method_sinkhorn,
+ numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_samp)
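+            # store the dual potentials (log of the Sinkhorn scaling vectors)
+            # to warm-start the Sinkhorn solver at the next BCD iteration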
+ duals_samp = (nx.log(dict_log["u"]), nx.log(dict_log["v"]))
+ elif eps_samp == 0:
+ pi_samp, dict_log = emd(
+ a=wx_samp, b=wy_samp, M=ot_cost, numItermax=nits_ot, log=True)
+ duals_samp = (dict_log["u"], dict_log["v"])
+ # update feature coupling
+ ot_cost = XY_sqr_T - 2 * X.T @ pi_samp @ Y # size nx_feat x ny_feat
+ if eps_feat > 0:
+ pi_feat, dict_log = sinkhorn(a=wx_feat, b=wy_feat, M=ot_cost, reg=eps_feat, method=method_sinkhorn,
+ numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_feat)
+ duals_feat = (nx.log(dict_log["u"]), nx.log(dict_log["v"]))
+ elif eps_feat == 0:
+ pi_feat, dict_log = emd(
+ a=wx_feat, b=wy_feat, M=ot_cost, numItermax=nits_ot, log=True)
+ duals_feat = (dict_log["u"], dict_log["v"])
+
+ if idx % eval_bcd == 0:
+ # update error
+ err = nx.sum(nx.abs(pi_samp - pi_samp_prev))
+
+ # COOT part
+ coot = nx.sum(ot_cost * pi_feat)
+ if alpha_samp != 0:
+ coot = coot + alpha_samp * nx.sum(M_samp * pi_samp)
+ # Entropic part
+ if eps_samp != 0:
+ coot = coot + eps_samp * compute_kl(pi_samp, wxy_samp)
+ if eps_feat != 0:
+ coot = coot + eps_feat * compute_kl(pi_feat, wxy_feat)
+ list_coot.append(coot)
+
+ if err < tol_bcd or abs(list_coot[-2] - list_coot[-1]) < early_stopping_tol:
+ break
+
+ if verbose:
+ print(
+ "CO-Optimal Transport cost at iteration {}: {}".format(idx + 1, coot))
+
+ # sanity check
+ if nx.sum(nx.isnan(pi_samp)) > 0 or nx.sum(nx.isnan(pi_feat)) > 0:
+ warnings.warn("There is NaN in coupling.")
+
+ if log:
+ dict_log = {"duals_sample": duals_samp,
+ "duals_feature": duals_feat,
+ "distances": list_coot[1:]}
+
+ return pi_samp, pi_feat, dict_log
+
+ else:
+ return pi_samp, pi_feat
+
+
+def co_optimal_transport2(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None,
+ epsilon=0, alpha=0, M_samp=None, M_feat=None,
+ warmstart=None, log=False, verbose=False, early_stopping_tol=1e-6,
+ nits_bcd=100, tol_bcd=1e-7, eval_bcd=1,
+ nits_ot=500, tol_sinkhorn=1e-7,
+ method_sinkhorn="sinkhorn"):
+ r"""Compute the CO-Optimal Transport distance between two measures.
+
+ Returns the CO-Optimal Transport distance between
+ :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and
+ :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`.
+
+ The function solves the following CO-Optimal Transport (COOT) problem:
+
+ .. math::
+ \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}}
+ &\quad \sum_{i,j,k,l}
+ (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l}
+ + \alpha_1 \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\
+ &+ \alpha_2 \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l}
+ + \varepsilon_1 \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T)
+ + \varepsilon_2 \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T)
+
+ Where :
+
+ - :math:`\mathbf{X}`: Data matrix in the source space
+ - :math:`\mathbf{Y}`: Data matrix in the target space
+ - :math:`\mathbf{M^{(s)}}`: Additional sample matrix
+ - :math:`\mathbf{M^{(f)}}`: Additional feature matrix
+ - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space
+ - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space
+ - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space
+ - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space
+
+ .. note:: This function allows epsilon to be zero.
+ In that case, the :any:`ot.lp.emd` solver of POT will be used.
+
+ Parameters
+ ----------
+ X : (n_sample_x, n_feature_x) array-like, float
+ First input matrix.
+ Y : (n_sample_y, n_feature_y) array-like, float
+ Second input matrix.
+ wx_samp : (n_sample_x, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix X.
+ Uniform distribution by default.
+ wx_feat : (n_feature_x, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix X.
+ Uniform distribution by default.
+ wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix Y.
+ Uniform distribution by default.
+ wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix Y.
+ Uniform distribution by default.
+    epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Regularization parameters for the entropic approximation of the sample and feature couplings.
+        Epsilon is allowed to contain 0; in that case, the EMD solver is used instead of the
+        Sinkhorn solver. If epsilon is a scalar, then the same value is applied to
+        the regularization of both the sample and feature couplings.
+    alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Coefficient of the linear terms with respect to the sample and feature couplings.
+        If alpha is a scalar, then the same value is applied to both linear terms.
+ M_samp : (n_sample_x, n_sample_y), float, optional (default = None)
+ Sample matrix with respect to the linear term on sample coupling.
+ M_feat : (n_feature_x, n_feature_y), float, optional (default = None)
+ Feature matrix with respect to the linear term on feature coupling.
+ warmstart : dictionary, optional (default = None)
+ Contains 4 keys:
+ - "duals_sample" and "duals_feature" whose values are
+ tuples of 2 vectors of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature dual vectors
+ if using Sinkhorn algorithm. Zero vectors by default.
+
+ - "pi_sample" and "pi_feature" whose values are matrices
+ of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature couplings.
+ Uniform distributions by default.
+ nits_bcd : int, optional (default = 100)
+ Number of Block Coordinate Descent (BCD) iterations to solve COOT.
+ tol_bcd : float, optional (default = 1e-7)
+ Tolerance of BCD scheme. If the L1-norm between the current and previous
+ sample couplings is under this threshold, then stop BCD scheme.
+ eval_bcd : int, optional (default = 1)
+ Multiplier of iteration at which the COOT cost is evaluated. For example,
+ if `eval_bcd = 8`, then the cost is calculated at iterations 8, 16, 24, etc...
+    nits_ot : int, optional (default = 500)
+ Number of iterations to solve each of the
+ two optimal transport problems in each BCD iteration.
+ tol_sinkhorn : float, optional (default = 1e-7)
+ Tolerance of Sinkhorn algorithm to stop the Sinkhorn scheme for
+ entropic optimal transport problem (if any) in each BCD iteration.
+ Only triggered when Sinkhorn solver is used.
+ method_sinkhorn : string, optional (default = "sinkhorn")
+ Method used in POT's `ot.sinkhorn` solver.
+        Only "sinkhorn" and "sinkhorn_log" are supported.
+ early_stopping_tol : float, optional (default = 1e-6)
+ Tolerance for the early stopping. If the absolute difference between
+ the last 2 recorded COOT distances is under this tolerance, then stop BCD scheme.
+ log : bool, optional (default = False)
+ If True then the cost and 4 dual vectors, including
+ 2 from sample and 2 from feature couplings, are recorded.
+ verbose : bool, optional (default = False)
+        If True, print the COOT cost at every multiple of `eval_bcd` iterations.
+
+ Returns
+ -------
+ float
+ CO-Optimal Transport distance.
+ dict
+        Contains logged information from the :any:`co_optimal_transport` solver.
+        Only returned if the `log` parameter is True.
+
+ References
+ ----------
+    .. [49] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
+        Advances in Neural Information Processing Systems, 33 (2020).
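+
+    Examples
+    --------
+    A minimal sketch of the scalar COOT cost on small random matrices
+    (for illustration only; the value depends on the random draw):
+
+    >>> import numpy as np
+    >>> from ot.coot import co_optimal_transport2
+    >>> X = np.random.randn(10, 4)
+    >>> Y = np.random.randn(8, 3)
+    >>> cost = co_optimal_transport2(X, Y)
+    >>> float(cost) >= 0
+    True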
+ """
+
+ pi_samp, pi_feat, dict_log = co_optimal_transport(X=X, Y=Y, wx_samp=wx_samp, wx_feat=wx_feat, wy_samp=wy_samp,
+ wy_feat=wy_feat, epsilon=epsilon, alpha=alpha, M_samp=M_samp,
+ M_feat=M_feat, warmstart=warmstart, nits_bcd=nits_bcd,
+ tol_bcd=tol_bcd, eval_bcd=eval_bcd, nits_ot=nits_ot,
+ tol_sinkhorn=tol_sinkhorn, method_sinkhorn=method_sinkhorn,
+ early_stopping_tol=early_stopping_tol,
+ log=True, verbose=verbose)
+
+ X, Y = list_to_array(X, Y)
+ nx = get_backend(X, Y)
+
+ nx_samp, nx_feat = X.shape
+ ny_samp, ny_feat = Y.shape
+
+ # measures on rows and columns
+ if wx_samp is None:
+ wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp
+ if wx_feat is None:
+ wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat
+ if wy_samp is None:
+ wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp
+ if wy_feat is None:
+ wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat
+
+ vx_samp, vy_samp = dict_log["duals_sample"]
+ vx_feat, vy_feat = dict_log["duals_feature"]
+
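+    # gradients of the COOT cost with respect to X and Y,
+    # computed with the optimal couplings held fixed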
+ gradX = 2 * X * (wx_samp[:, None] * wx_feat[None, :]) - \
+ 2 * pi_samp @ Y @ pi_feat.T # shape (nx_samp, nx_feat)
+ gradY = 2 * Y * (wy_samp[:, None] * wy_feat[None, :]) - \
+ 2 * pi_samp.T @ X @ pi_feat # shape (ny_samp, ny_feat)
+
+ coot = dict_log["distances"][-1]
+ coot = nx.set_gradients(coot, (wx_samp, wx_feat, wy_samp, wy_feat, X, Y),
+ (vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY))
+
+ if log:
+ return coot, dict_log
+
+ else:
+ return coot
diff --git a/test/test_coot.py b/test/test_coot.py
new file mode 100644
index 0000000..ef68a9b
--- /dev/null
+++ b/test/test_coot.py
@@ -0,0 +1,359 @@
+"""Tests for module COOT on OT """
+
+# Author: Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+import pytest
+
+
+@pytest.mark.parametrize("verbose", [False, True, 1, 0])
+def test_coot(nx, verbose):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, verbose=verbose)
+ pi_sample_nx, pi_feature_nx = coot(X=xs_nx, Y=xt_nx, verbose=verbose)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, verbose=verbose)
+ coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, verbose=verbose))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_entropic_coot(nx):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ epsilon = (1, 1e-1)
+ nits_ot = 2000
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test entropic COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot)
+ coot_nx = nx.to_numpy(
+ coot2(X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot))
+
+ np.testing.assert_allclose(coot_np, coot_nx, atol=1e-08)
+
+
+def test_coot_with_linear_terms(nx):
+ n_samples = 60 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ M_samp = np.ones((n_samples, n_samples))
+ np.fill_diagonal(np.fliplr(M_samp), 0)
+ M_feat = np.ones((2, 2))
+ np.fill_diagonal(M_feat, 0)
+ M_samp_nx, M_feat_nx = nx.from_numpy(M_samp), nx.from_numpy(M_feat)
+
+ alpha = (1, 2)
+
+ # test couplings
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ pi_sample, pi_feature = coot(
+ X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+
+ coot_np = coot2(X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
+ coot_nx = nx.to_numpy(
+ coot2(X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_coot_raise_value_error(nx):
+ n_samples = 80 # nb samples
+
+ mu_s = np.array([2, 4])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=43)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # raise value error of method sinkhorn
+ def coot_sh(method_sinkhorn):
+ return coot(X=xs, Y=xt, method_sinkhorn=method_sinkhorn)
+
+ def coot_sh_nx(method_sinkhorn):
+ return coot(X=xs_nx, Y=xt_nx, method_sinkhorn=method_sinkhorn)
+
+ np.testing.assert_raises(ValueError, coot_sh, "not_sinkhorn")
+ np.testing.assert_raises(ValueError, coot_sh_nx, "not_sinkhorn")
+
+ # raise value error for epsilon
+ def coot_eps(epsilon):
+ return coot(X=xs, Y=xt, epsilon=epsilon)
+
+ def coot_eps_nx(epsilon):
+ return coot(X=xs_nx, Y=xt_nx, epsilon=epsilon)
+
+ np.testing.assert_raises(ValueError, coot_eps, (1, 2, 3))
+ np.testing.assert_raises(ValueError, coot_eps_nx, [1, 2, 3, 4])
+
+ # raise value error for alpha
+ def coot_alpha(alpha):
+ return coot(X=xs, Y=xt, alpha=alpha)
+
+ def coot_alpha_nx(alpha):
+ return coot(X=xs_nx, Y=xt_nx, alpha=alpha)
+
+ np.testing.assert_raises(ValueError, coot_alpha, [1])
+ np.testing.assert_raises(ValueError, coot_alpha_nx, np.arange(4))
+
+
+def test_coot_warmstart(nx):
+ n_samples = 80 # nb samples
+
+ mu_s = np.array([2, 3])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=125)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ # initialize warmstart
+ init_pi_sample = np.random.rand(n_samples, n_samples)
+ init_pi_sample = init_pi_sample / np.sum(init_pi_sample)
+ init_pi_sample_nx = nx.from_numpy(init_pi_sample)
+
+ init_pi_feature = np.random.rand(2, 2)
+    init_pi_feature = init_pi_feature / np.sum(init_pi_feature)
+ init_pi_feature_nx = nx.from_numpy(init_pi_feature)
+
+ init_duals_sample = (np.random.random(n_samples) * 2 - 1,
+ np.random.random(n_samples) * 2 - 1)
+ init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]),
+ nx.from_numpy(init_duals_sample[1]))
+
+ init_duals_feature = (np.random.random(2) * 2 - 1,
+ np.random.random(2) * 2 - 1)
+ init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]),
+ nx.from_numpy(init_duals_feature[1]))
+
+ warmstart = {
+ "pi_sample": init_pi_sample,
+ "pi_feature": init_pi_feature,
+ "duals_sample": init_duals_sample,
+ "duals_feature": init_duals_feature
+ }
+
+ warmstart_nx = {
+ "pi_sample": init_pi_sample_nx,
+ "pi_feature": init_pi_feature_nx,
+ "duals_sample": init_duals_sample_nx,
+ "duals_feature": init_duals_feature_nx
+ }
+
+ # test couplings
+ pi_sample, pi_feature = coot(X=xs, Y=xt, warmstart=warmstart)
+ pi_sample_nx, pi_feature_nx = coot(
+ X=xs_nx, Y=xt_nx, warmstart=warmstart_nx)
+ pi_sample_nx = nx.to_numpy(pi_sample_nx)
+ pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+ anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+ id_feature = np.eye(2, 2) / 2
+
+ np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+ np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+ np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+ # test marginal distributions
+ px_s, px_f = ot.unif(n_samples), ot.unif(2)
+ py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+ np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+ np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+ np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+ np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+ # test COOT distance
+ coot_np = coot2(X=xs, Y=xt, warmstart=warmstart)
+ coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, warmstart=warmstart_nx))
+ np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+ np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_coot_log(nx):
+ n_samples = 90 # nb samples
+
+ mu_s = np.array([-2, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(
+ n_samples, mu_s, cov_s, random_state=43)
+ xt = xs[::-1].copy()
+ xs_nx = nx.from_numpy(xs)
+ xt_nx = nx.from_numpy(xt)
+
+ pi_sample, pi_feature, log = coot(X=xs, Y=xt, log=True)
+ pi_sample_nx, pi_feature_nx, log_nx = coot(X=xs_nx, Y=xt_nx, log=True)
+
+ duals_sample, duals_feature = log["duals_sample"], log["duals_feature"]
+ assert len(duals_sample) == 2
+ assert len(duals_feature) == 2
+ assert len(duals_sample[0]) == n_samples
+ assert len(duals_sample[1]) == n_samples
+ assert len(duals_feature[0]) == 2
+ assert len(duals_feature[1]) == 2
+
+ duals_sample_nx = log_nx["duals_sample"]
+ assert len(duals_sample_nx) == 2
+ assert len(duals_sample_nx[0]) == n_samples
+ assert len(duals_sample_nx[1]) == n_samples
+
+ duals_feature_nx = log_nx["duals_feature"]
+ assert len(duals_feature_nx) == 2
+ assert len(duals_feature_nx[0]) == 2
+ assert len(duals_feature_nx[1]) == 2
+
+ list_coot = log["distances"]
+ assert len(list_coot) >= 1
+
+ list_coot_nx = log_nx["distances"]
+ assert len(list_coot_nx) >= 1
+
+ # test with coot distance
+ coot_np, log = coot2(X=xs, Y=xt, log=True)
+ coot_nx, log_nx = coot2(X=xs_nx, Y=xt_nx, log=True)
+
+ duals_sample, duals_feature = log["duals_sample"], log["duals_feature"]
+ assert len(duals_sample) == 2
+ assert len(duals_feature) == 2
+ assert len(duals_sample[0]) == n_samples
+ assert len(duals_sample[1]) == n_samples
+ assert len(duals_feature[0]) == 2
+ assert len(duals_feature[1]) == 2
+
+ duals_sample_nx = log_nx["duals_sample"]
+ assert len(duals_sample_nx) == 2
+ assert len(duals_sample_nx[0]) == n_samples
+ assert len(duals_sample_nx[1]) == n_samples
+
+ duals_feature_nx = log_nx["duals_feature"]
+ assert len(duals_feature_nx) == 2
+ assert len(duals_feature_nx[0]) == 2
+ assert len(duals_feature_nx[1]) == 2
+
+ list_coot = log["distances"]
+ assert len(list_coot) >= 1
+
+ list_coot_nx = log_nx["distances"]
+ assert len(list_coot_nx) >= 1