-rw-r--r--   docs/source/all.rst                        |    6
-rw-r--r--   docs/source/readme.rst                     |    8
-rwxr-xr-x   examples/plot_partial_wass_and_gromov.py   |  163
-rw-r--r--   ot/__init__.py                             |   81
-rwxr-xr-x   ot/partial.py                              | 1014
-rwxr-xr-x   test/test_partial.py                       |  141
6 files changed, 1332 insertions, 81 deletions
diff --git a/docs/source/all.rst b/docs/source/all.rst
index c968aa1..a6d9790 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -86,3 +86,9 @@ ot.unbalanced
.. automodule:: ot.unbalanced
:members:
+
+ot.partial
+-------------
+
+.. automodule:: ot.partial
+ :members:
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index 0871779..d5f2161 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -391,6 +391,14 @@ of the 36th International Conference on Machine Learning (ICML).
`Learning with a Wasserstein Loss <http://cbcl.mit.edu/wasserstein/>`__
Advances in Neural Information Processing Systems (NIPS).
+[26] Caffarelli, L. A., McCann, R. J. (2010). `Free boundaries in optimal transport and
+Monge-Ampere obstacle problems <http://www.math.toronto.edu/~mccann/papers/annals2010.pdf>`__,
+Annals of mathematics, 673-730.
+
+[27] Chapel, L., Alaya, M., Gasso, G. (2020). `Partial Gromov-Wasserstein with Applications
+on Positive-Unlabeled Learning <https://arxiv.org/abs/2002.08276>`__. arXiv preprint
+arXiv:2002.08276.
+
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
diff --git a/examples/plot_partial_wass_and_gromov.py b/examples/plot_partial_wass_and_gromov.py
new file mode 100755
index 0000000..30b3fc0
--- /dev/null
+++ b/examples/plot_partial_wass_and_gromov.py
@@ -0,0 +1,163 @@
+# -*- coding: utf-8 -*-
+"""
+==================================================
+Partial Wasserstein and Gromov-Wasserstein example
+==================================================
+
+This example is designed to show how to use the partial (Gromov-)Wasserstein
+distance computation in POT.
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@irisa.fr>
+# License: MIT License
+
+import scipy as sp
+import numpy as np
+import matplotlib.pylab as pl
+from mpl_toolkits.mplot3d import Axes3D  # noqa: F401, registers the '3d' projection
+import ot
+
+
+#############################################################################
+#
+# Sample two 2D Gaussian distributions and plot them
+# --------------------------------------------------
+#
+# For demonstration purposes, we sample two Gaussian distributions in a 2D
+# space and add some random noise.
+
+
+n_samples = 20 # nb samples (gaussian)
+n_noise = 20 # nb of samples (noise)
+
+mu = np.array([0, 0])
+cov = np.array([[1, 0], [0, 2]])
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+M = sp.spatial.distance.cdist(xs, xt)
+
+fig = pl.figure()
+ax1 = fig.add_subplot(131)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(132)
+ax2.scatter(xt[:, 0], xt[:, 1], color='r')
+ax3 = fig.add_subplot(133)
+ax3.imshow(M)
+pl.show()
+
+#############################################################################
+#
+# Compute partial Wasserstein plans and distance,
+# by transporting 50% of the mass
+# ----------------------------------------------
+
+p = ot.unif(n_samples + n_noise)
+q = ot.unif(n_samples + n_noise)
+
+w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True)
+w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5,
+ log=True)
+
+print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist']))
+print('Entropic partial Wasserstein distance (m = 0.5): ' +
+ str(log['partial_w_dist']))
+
+pl.figure(1, (10, 5))
+pl.subplot(1, 2, 1)
+pl.imshow(w0, cmap='jet')
+pl.title('Partial Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(w, cmap='jet')
+pl.title('Entropic partial Wasserstein')
+pl.show()
+
+
+#############################################################################
+#
+# Sample a 2D and a 3D Gaussian distribution and plot them
+# --------------------------------------------------------
+#
+# The Gromov-Wasserstein distance allows computing distances between samples
+# that do not belong to the same metric space. For demonstration purposes, we
+# sample two Gaussian distributions in 2- and 3-dimensional spaces.
+
+n_samples = 20 # nb samples
+n_noise = 10 # nb of samples (noise)
+
+p = ot.unif(n_samples + n_noise)
+q = ot.unif(n_samples + n_noise)
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([0, 0, 0])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+
+#############################################################################
+#
+# Compute partial Gromov-Wasserstein plans and distance,
+# by transporting 100% and 2/3 of the mass
+# -----------------------------------------------------
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+print('-----m = 1')
+m = 1
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
+ log=True)
+res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+
+print('Partial Gromov-Wasserstein distance (m = 1): ' +
+      str(log0['partial_gw_dist']))
+print('Entropic partial Gromov-Wasserstein distance (m = 1): ' +
+      str(log['partial_gw_dist']))
+
+pl.figure(1, (10, 5))
+pl.title("mass to be transported m = 1")
+pl.subplot(1, 2, 1)
+pl.imshow(res0, cmap='jet')
+pl.title('Partial Gromov-Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(res, cmap='jet')
+pl.title('Entropic partial Gromov-Wasserstein')
+pl.show()
+
+print('-----m = 2/3')
+m = 2 / 3
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+
+print('Partial Gromov-Wasserstein distance (m = 2/3): ' +
+      str(log0['partial_gw_dist']))
+print('Entropic partial Gromov-Wasserstein distance (m = 2/3): ' +
+      str(log['partial_gw_dist']))
+
+pl.figure(1, (10, 5))
+pl.title("mass to be transported m = 2/3")
+pl.subplot(1, 2, 1)
+pl.imshow(res0, cmap='jet')
+pl.title('Partial Gromov-Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(res, cmap='jet')
+pl.title('Entropic partial Gromov-Wasserstein')
+pl.show()
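
A quick sanity check on the new API (a minimal sketch, assuming a POT build that includes this patch): the mass of a partial plan sums to m, and at full mass the partial problem reduces to the classical EMD, so the optimal costs should coincide.

    import numpy as np
    import ot

    np.random.seed(0)
    xs = np.random.randn(15, 2)
    xt = np.random.randn(15, 2) + 2
    M = ot.dist(xs, xt)                      # squared Euclidean costs
    p, q = ot.unif(15), ot.unif(15)

    # transport only half of the mass
    w_half = ot.partial.partial_wasserstein(p, q, M, m=0.5)
    print(np.sum(w_half))                    # ~0.5

    # at m = 1 (full mass), the partial cost should match the EMD cost
    w_full = ot.partial.partial_wasserstein(p, q, M, m=1)
    print(np.allclose(np.sum(w_full * M), ot.emd2(p, q, M)))
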
diff --git a/ot/__init__.py b/ot/__init__.py
deleted file mode 100644
index 89c7936..0000000
--- a/ot/__init__.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""
-
-This is the main module of the POT toolbox. It provides easy access to
-a number of sub-modules and functions described below.
-
-.. note::
-
-
- Here is a list of the submodules and short description of what they contain.
-
- - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
- - :any:`ot.bregman` contains OT solvers for the entropic OT problems using
- Bregman projections.
- - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
- - :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT
- problems.
- - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
- Wasserstein problems.
- - :any:`ot.optim` contains generic solvers OT based optimization problems
- - :any:`ot.da` contains classes and function related to Monge mapping
- estimation and Domain Adaptation (DA).
- - :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers
- - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
- Discriminant Analysis.
- - :any:`ot.utils` contains utility functions such as distance computation and
- timing.
- - :any:`ot.datasets` contains toy dataset generation functions.
- - :any:`ot.plot` contains visualization functions
- - :any:`ot.stochastic` contains stochastic solvers for regularized OT.
- - :any:`ot.unbalanced` contains solvers for regularized unbalanced OT.
-
-.. warning::
- The list of automatically imported sub-modules is as follows:
- :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
- :py:mod:`ot.utils`, :py:mod:`ot.datasets`,
- :py:mod:`ot.gromov`, :py:mod:`ot.smooth`
- :py:mod:`ot.stochastic`
-
- The following sub-modules are not imported due to additional dependencies:
-
- - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
- - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU.
- - :any:`ot.plot` : depends on :code:`matplotlib`
-
-"""
-
-# Author: Remi Flamary <remi.flamary@unice.fr>
-# Nicolas Courty <ncourty@irisa.fr>
-#
-# License: MIT License
-
-
-# All submodules and packages
-from . import lp
-from . import bregman
-from . import optim
-from . import utils
-from . import datasets
-from . import da
-from . import gromov
-from . import smooth
-from . import stochastic
-from . import unbalanced
-
-# OT functions
-from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
-from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
-from .da import sinkhorn_lpl1_mm
-
-# utils functions
-from .utils import dist, unif, tic, toc, toq
-
-__version__ = "0.6.0"
-
-__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets',
- 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d',
- 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
- 'sinkhorn_unbalanced', 'barycenter_unbalanced',
- 'sinkhorn_unbalanced2']
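
Note that the example above and the tests below access the new solvers as ot.partial.*; after a bare `import ot`, that attribute lookup only works if the package `__init__.py` imports the submodule. A minimal sketch of the required hook (illustrative, not the verbatim patch):

    # in ot/__init__.py
    from . import partial  # noqa: F401, exposes ot.partial after `import ot`
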
diff --git a/ot/partial.py b/ot/partial.py
new file mode 100755
index 0000000..3425acb
--- /dev/null
+++ b/ot/partial.py
@@ -0,0 +1,1014 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Partial OT
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@irisa.fr>
+# License: MIT License
+
+import numpy as np
+
+from .lp import emd
+
+
+def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
+ **kwargs):
+ r"""
+ Solves the partial optimal transport problem for the quadratic cost
+ and returns the OT plan
+
+ The function considers the following problem:
+
+ .. math::
+ \gamma = \arg\min_\gamma <\gamma,(M-\lambda)>_F
+
+ s.t.
+ \gamma\geq 0 \\
+ \gamma 1 \leq a\\
+ \gamma^T 1 \leq b\\
+ 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+
+
+ or equivalently:
+
+ .. math::
+ \gamma = \arg\min_\gamma <\gamma,M>_F + \sqrt{\lambda/2}
+ (\|\gamma 1 - a\|_1 + \|\gamma^T 1 - b\|_1)
+
+ s.t.
+ \gamma\geq 0 \\
+
+
+ where :
+
+ - M is the metric cost matrix
+ - a and b are source and target unbalanced distributions
+ - :math:`\lambda` is the Lagrangian cost. Tuning its value allows attaining
+ a prescribed amount m of transported mass
+
+ The formulation of the problem has been proposed in [26]_
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Unnormalized histograms of dimension dim_b
+ M : np.ndarray (dim_a, dim_b)
+ cost matrix for the quadratic cost
+ reg_m : float, optional
+ Lagrangian cost
+ nb_dummies : int, optional, default: 1
+ number of reservoir points to be added (to avoid numerical
+ instabilities, increase its value if an error is raised)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> import numpy as np
+ >>> a = [.1, .2]
+ >>> b = [.1, .1]
+ >>> M = [[0., 1.], [2., 3.]]
+ >>> np.round(partial_wasserstein_lagrange(a,b,M), 2)
+ array([[0.1, 0. ],
+ [0. , 0.1]])
+ >>> np.round(partial_wasserstein_lagrange(a,b,M,reg_m=2), 2)
+ array([[0.1, 0. ],
+ [0. , 0. ]])
+
+ References
+ ----------
+
+ .. [26] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
+ optimal transport and Monge-Ampere obstacle problems. Annals of
+ mathematics, 673-730.
+
+ See Also
+ --------
+ ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass
+ """
+
+ if np.sum(a) > 1 or np.sum(b) > 1:
+ raise ValueError("Problem infeasible. Check that a and b are in the "
+ "simplex")
+
+ if reg_m is None:
+ reg_m = np.max(M) + 1
+ if reg_m < -np.max(M):
+ return np.zeros((len(a), len(b)))
+
+ eps = 1e-20
+ M = np.asarray(M, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ a = np.asarray(a, dtype=np.float64)
+
+ M_star = M - reg_m # modified cost matrix
+
+ # trick to speed up the computation: keep only the rows/columns whose
+ # marginals can be greater than 0 (that is to say, where M_star < 0)
+ idx_x = np.where(np.min(M_star, axis=1) < eps)[0]
+ idx_y = np.where(np.min(M_star, axis=0) < eps)[0]
+
+ # extend a, b, M with "reservoir" or "dummy" points
+ M_extended = np.zeros((len(idx_x) + nb_dummies, len(idx_y) + nb_dummies))
+ M_extended[:len(idx_x), :len(idx_y)] = M_star[np.ix_(idx_x, idx_y)]
+
+ a_extended = np.append(a[idx_x], [(np.sum(a) - np.sum(a[idx_x]) +
+ np.sum(b)) / nb_dummies] * nb_dummies)
+ b_extended = np.append(b[idx_y], [(np.sum(b) - np.sum(b[idx_y]) +
+ np.sum(a)) / nb_dummies] * nb_dummies)
+
+ gamma_extended, log_emd = emd(a_extended, b_extended, M_extended, log=True,
+ **kwargs)
+ gamma = np.zeros((len(a), len(b)))
+ gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies]
+
+ if log_emd['warning'] is not None:
+ raise ValueError("Error in the EMD resolution: try to increase the"
+ " number of dummy points")
+ log_emd['cost'] = np.sum(gamma * M)
+ if log:
+ return gamma, log_emd
+ else:
+ return gamma
+
+
+def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
+ r"""
+ Solves the partial optimal transport problem for the quadratic cost
+ and returns the OT plan
+
+ The function considers the following problem:
+
+ .. math::
+ \gamma = \arg\min_\gamma <\gamma,M>_F
+
+ s.t.
+ \gamma\geq 0 \\
+ \gamma 1 \leq a\\
+ \gamma^T 1 \leq b\\
+ 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+
+
+ where :
+
+ - M is the metric cost matrix
+ - a and b are source and target unbalanced distributions
+ - m is the amount of mass to be transported
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Unnormalized histograms of dimension dim_b
+ M : np.ndarray (dim_a, dim_b)
+ cost matrix for the quadratic cost
+ m : float, optional
+ amount of mass to be transported
+ nb_dummies : int, optional, default:1
+ number of reservoir points to be added (to avoid numerical
+ instabilities, increase its value if an error is raised)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ :math:`\gamma` : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> import numpy as np
+ >>> a = [.1, .2]
+ >>> b = [.1, .1]
+ >>> M = [[0., 1.], [2., 3.]]
+ >>> np.round(partial_wasserstein(a,b,M), 2)
+ array([[0.1, 0. ],
+ [0. , 0.1]])
+ >>> np.round(partial_wasserstein(a,b,M,m=0.1), 2)
+ array([[0.1, 0. ],
+ [0. , 0. ]])
+
+ References
+ ----------
+ .. [26] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
+ optimal transport and Monge-Ampere obstacle problems. Annals of
+ mathematics, 673-730.
+ .. [27] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Gromov-
+ Wasserstein with Applications on Positive-Unlabeled Learning".
+ arXiv preprint arXiv:2002.08276.
+
+ See Also
+ --------
+ ot.partial.partial_wasserstein_lagrange: Partial Wasserstein with
+ regularization on the marginals
+ ot.partial.entropic_partial_wasserstein: Partial Wasserstein with a
+ entropic regularization parameter
+ """
+
+ if m is None:
+ return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
+ elif m < 0:
+ raise ValueError("Problem infeasible. Parameter m should be greater"
+ " than 0.")
+ elif m > np.min((np.sum(a), np.sum(b))):
+ raise ValueError("Problem infeasible. Parameter m should be lower or"
+ " equal to min(|a|_1, |b|_1).")
+
+ b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
+ a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
+ M_extended = np.zeros((len(a_extended), len(b_extended)))
+ M_extended[-1, -1] = np.max(M) * 1e5
+ M_extended[:len(a), :len(b)] = M
+
+ gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
+ **kwargs)
+ if log_emd['warning'] is not None:
+ raise ValueError("Error in the EMD resolution: try to increase the"
+ " number of dummy points")
+ log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)])
+
+ if log:
+ return gamma[:len(a), :len(b)], log_emd
+ else:
+ return gamma[:len(a), :len(b)]
+
+
+def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
+ r"""
+ Solves the partial optimal transport problem for the quadratic cost
+ and returns the partial Wasserstein discrepancy
+
+ The function considers the following problem:
+
+ .. math::
+ \gamma = \arg\min_\gamma <\gamma,M>_F
+
+ s.t.
+ \gamma\geq 0 \\
+ \gamma 1 \leq a\\
+ \gamma^T 1 \leq b\\
+ 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+
+
+ where :
+
+ - M is the metric cost matrix
+ - a and b are source and target unbalanced distributions
+ - m is the amount of mass to be transported
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Unnormalized histograms of dimension dim_b
+ M : np.ndarray (dim_a, dim_b)
+ cost matrix for the quadratic cost
+ m : float, optional
+ amount of mass to be transported
+ nb_dummies : int, optional, default:1
+ number of reservoir points to be added (to avoid numerical
+ instabilities, increase its value if an error is raised)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ partial_w_dist : float
+ partial Wasserstein distance
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> import numpy as np
+ >>> a = [.1, .2]
+ >>> b = [.1, .1]
+ >>> M = [[0., 1.], [2., 3.]]
+ >>> np.round(partial_wasserstein2(a, b, M), 1)
+ 0.3
+ >>> np.round(partial_wasserstein2(a,b,M,m=0.1), 1)
+ 0.0
+
+ References
+ ----------
+ .. [26] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
+ optimal transport and Monge-Ampere obstacle problems. Annals of
+ mathematics, 673-730.
+ .. [27] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Gromov-
+ Wasserstein with Applications on Positive-Unlabeled Learning".
+ arXiv preprint arXiv:2002.08276.
+ """
+
+ partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
+ **kwargs)
+
+ log_w['T'] = partial_gw
+
+ if log:
+ return np.sum(partial_gw * M), log_w
+ else:
+ return np.sum(partial_gw * M)
+
+
+def gwgrad_partial(C1, C2, T):
+ """Compute the GW gradient. Note: we can not use the trick in [12]_ as
+ the marginals may not sum to 1.
+
+ Parameters
+ ----------
+ C1: array of shape (n_p,n_p)
+ intra-source (P) cost matrix
+
+ C2: array of shape (n_u,n_u)
+ intra-target (U) cost matrix
+
+ T : array of shape (n_p+nb_dummies, n_u)
+ Transport matrix
+
+ Returns
+ -------
+ numpy.array of shape (n_p+nb_dummies, n_u)
+ gradient
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+ """
+ cC1 = np.dot(C1 ** 2 / 2, np.dot(T, np.ones(C2.shape[0]).reshape(-1, 1)))
+ cC2 = np.dot(np.dot(np.ones(C1.shape[0]).reshape(1, -1), T), C2 ** 2 / 2)
+ constC = cC1 + cC2
+ A = -np.dot(C1, T).dot(C2.T)
+ tens = constC + A
+ return tens * 2
+
+
+def gwloss_partial(C1, C2, T):
+ """Compute the GW loss.
+
+ Parameters
+ ----------
+ C1: array of shape (n_p,n_p)
+ intra-source (P) cost matrix
+
+ C2: array of shape (n_u,n_u)
+ intra-target (U) cost matrix
+
+ T : array of shape (n_p+nb_dummies, n_u)
+ Transport matrix
+
+ Returns
+ -------
+ GW loss
+ """
+ g = gwgrad_partial(C1, C2, T) * 0.5
+ return np.sum(g * T)
+
+
+def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
+ thres=1, numItermax=1000, tol=1e-7,
+ log=False, verbose=False, **kwargs):
+ r"""
+ Solves the partial optimal transport problem
+ and returns the OT plan
+
+ The function considers the following problem:
+
+ .. math::
+ \gamma = \arg\min_\gamma \sum_{i,j,k,l}
+ L(C1_{i,k},C2_{j,l})\cdot\gamma_{i,j}\cdot\gamma_{k,l}
+
+ s.t. \gamma 1 \leq a \\
+ \gamma^T 1 \leq b \\
+ \gamma\geq 0 \\
+ 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\
+
+ where :
+
+ - C1 is the metric cost matrix in the source space
+ - C2 is the metric cost matrix in the target space
+ - p and q are the sample weights
+ - L is the quadratic loss function
+ - m is the amount of mass to be transported
+
+ The formulation of the problem has been proposed in [27]_
+
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ m : float, optional
+ Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ nb_dummies : int, optional
+ Number of dummy points to add (avoid instabilities in the EMD solver)
+ G0 : ndarray, shape (ns, nt), optional
+ Initialisation of the transportation matrix
+ thres : float, optional
+ quantile of the positive gradient values used to fill the zero
+ entries of the cost matrix (default: 1)
+ numItermax : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ log : bool, optional
+ return log if True
+ verbose : bool, optional
+ Print information along iterations
+ **kwargs : dict
+ parameters can be directly passed to the emd solver
+
+
+ Returns
+ -------
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> import scipy as sp
+ >>> a = np.array([0.25] * 4)
+ >>> b = np.array([0.25] * 4)
+ >>> x = np.array([1,2,100,200]).reshape((-1,1))
+ >>> y = np.array([3,2,98,199]).reshape((-1,1))
+ >>> C1 = sp.spatial.distance.cdist(x, x)
+ >>> C2 = sp.spatial.distance.cdist(y, y)
+ >>> np.round(partial_gromov_wasserstein(C1, C2, a, b),2)
+ array([[0. , 0.25, 0. , 0. ],
+ [0.25, 0. , 0. , 0. ],
+ [0. , 0. , 0.25, 0. ],
+ [0. , 0. , 0. , 0.25]])
+ >>> np.round(partial_gromov_wasserstein(C1, C2, a, b, m=0.25),2)
+ array([[0. , 0. , 0. , 0. ],
+ [0. , 0. , 0. , 0. ],
+ [0. , 0. , 0. , 0. ],
+ [0. , 0. , 0. , 0.25]])
+
+ References
+ ----------
+ .. [27] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Gromov-
+ Wasserstein with Applications on Positive-Unlabeled Learning".
+ arXiv preprint arXiv:2002.08276.
+
+ """
+
+ if m is None:
+ m = np.min((np.sum(p), np.sum(q)))
+ elif m < 0:
+ raise ValueError("Problem infeasible. Parameter m should be greater"
+ " than 0.")
+ elif m > np.min((np.sum(p), np.sum(q))):
+ raise ValueError("Problem infeasible. Parameter m should be lower or"
+ " equal to min(|p|_1, |q|_1).")
+
+ if G0 is None:
+ G0 = np.outer(p, q)
+
+ dim_G_extended = (len(p) + nb_dummies, len(q) + nb_dummies)
+ q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies)
+ p_extended = np.append(p, [(np.sum(q) - m) / nb_dummies] * nb_dummies)
+
+ cpt = 0
+ err = 1
+ eps = 1e-20
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < numItermax):
+
+ Gprev = G0
+
+ M = gwgrad_partial(C1, C2, G0)
+ M[M < eps] = np.quantile(M[M > eps], thres)
+
+ M_emd = np.ones(dim_G_extended) * np.max(M) * 1e2
+ M_emd[:len(p), :len(q)] = M
+ M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
+ M_emd = np.asarray(M_emd, dtype=np.float64)
+
+ Gc, logemd = emd(p_extended, q_extended, M_emd, log=True, **kwargs)
+
+ if logemd['warning'] is not None:
+ raise ValueError("Error in the EMD resolution: try to increase the"
+ " number of dummy points")
+
+ G0 = Gc[:len(p), :len(q)]
+
+ if cpt % 10 == 0: # to speed up the computations
+ err = np.linalg.norm(G0 - Gprev)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}|{:12s}'.format(
+ 'It.', 'Err', 'Loss') + '\n' + '-' * 31)
+ print('{:5d}|{:8e}|{:8e}'.format(cpt, err,
+ gwloss_partial(C1, C2, G0)))
+
+ cpt += 1
+
+ if log:
+ log['partial_gw_dist'] = gwloss_partial(C1, C2, G0)
+ return G0[:len(p), :len(q)], log
+ else:
+ return G0[:len(p), :len(q)]
+
+
+def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
+ thres=0.75, numItermax=1000, tol=1e-7,
+ log=False, verbose=False, **kwargs):
+ r"""
+ Solves the partial optimal transport problem
+ and returns the partial Gromov-Wasserstein discrepancy
+
+ The function considers the following problem:
+
+ .. math::
+ \gamma = \arg\min_\gamma \sum_{i,j,k,l}
+ L(C1_{i,k},C2_{j,l})\cdot\gamma_{i,j}\cdot\gamma_{k,l}
+
+ s.t. \gamma 1 \leq a \\
+ \gamma^T 1 \leq b \\
+ \gamma\geq 0 \\
+ 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\
+
+ where :
+
+ - C1 is the metric cost matrix in the source space
+ - C2 is the metric cost matrix in the target space
+ - p and q are the sample weights
+ - L is the quadratic loss function
+ - m is the amount of mass to be transported
+
+ The formulation of the problem has been proposed in [27]_
+
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ m : float, optional
+ Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ nb_dummies : int, optional
+ Number of dummy points to add (avoid instabilities in the EMD solver)
+ G0 : ndarray, shape (ns, nt), optional
+ Initialisation of the transportation matrix
+ thres : float, optional
+ quantile of the positive gradient values used to fill the zero
+ entries of the cost matrix (default: 0.75)
+ numItermax : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ log : bool, optional
+ return log if True
+ verbose : bool, optional
+ Print information along iterations
+ **kwargs : dict
+ parameters can be directly passed to the emd solver
+
+
+ Returns
+ -------
+ partial_gw_dist : float
+ partial GW discrepancy
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> import scipy as sp
+ >>> a = np.array([0.25] * 4)
+ >>> b = np.array([0.25] * 4)
+ >>> x = np.array([1,2,100,200]).reshape((-1,1))
+ >>> y = np.array([3,2,98,199]).reshape((-1,1))
+ >>> C1 = sp.spatial.distance.cdist(x, x)
+ >>> C2 = sp.spatial.distance.cdist(y, y)
+ >>> np.round(partial_gromov_wasserstein2(C1, C2, a, b),2)
+ 1.69
+ >>> np.round(partial_gromov_wasserstein2(C1, C2, a, b, m=0.25),2)
+ 0.0
+
+ References
+ ----------
+ .. [27] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Gromov-
+ Wasserstein with Applications on Positive-Unlabeled Learning".
+ arXiv preprint arXiv:2002.08276.
+
+ """
+
+ partial_gw, log_gw = partial_gromov_wasserstein(C1, C2, p, q, m,
+ nb_dummies, G0, thres,
+ numItermax, tol, True,
+ verbose, **kwargs)
+
+ log_gw['T'] = partial_gw
+
+ if log:
+ return log_gw['partial_gw_dist'], log_gw
+ else:
+ return log_gw['partial_gw_dist']
+
+
+def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
+ stopThr=1e-100, verbose=False, log=False):
+ r"""
+ Solves the partial optimal transport problem
+ and returns the OT plan
+
+ The function considers the following problem:
+
+ .. math::
+ \gamma = \arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 \leq a \\
+ \gamma^T 1 \leq b \\
+ \gamma\geq 0 \\
+ 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\
+
+ where :
+
+ - M is the metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are the sample weights
+ - m is the amount of mass to be transported
+
+ The formulation of the problem has been proposed in [3]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Unnormalized histograms of dimension dim_b
+ M : np.ndarray (dim_a, dim_b)
+ cost matrix
+ reg : float
+ Regularization term > 0
+ m : float, optional
+ Amount of mass to be transported
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a = [.1, .2]
+ >>> b = [.1, .1]
+ >>> M = [[0., 1.], [2., 3.]]
+ >>> np.round(entropic_partial_wasserstein(a, b, M, 1, 0.1), 2)
+ array([[0.06, 0.02],
+ [0.01, 0. ]])
+
+
+ References
+ ----------
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+ See Also
+ --------
+ ot.partial.partial_wasserstein: exact Partial Wasserstein
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ dim_a, dim_b = M.shape
+ dx = np.ones(dim_a)
+ dy = np.ones(dim_b)
+
+ if len(a) == 0:
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
+ if len(b) == 0:
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
+
+ if m is None:
+ m = np.min((np.sum(a), np.sum(b)))
+ if m < 0:
+ raise ValueError("Problem infeasible. Parameter m should be greater"
+ " than 0.")
+ if m > np.min((np.sum(a), np.sum(b))):
+ raise ValueError("Problem infeasible. Parameter m should be lower or"
+ " equal to min(|a|_1, |b|_1).")
+
+ log_e = {'err': []}
+
+ # Next two lines equivalent to K=np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+ np.multiply(K, m / np.sum(K), out=K)  # rescale the total mass to m
+
+ err, cpt = 1, 0
+
+ while (err > stopThr and cpt < numItermax):
+ Kprev = K
+ K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
+ K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
+ K = K2 * (m / np.sum(K2))
+
+ if np.any(np.isnan(K)) or np.any(np.isinf(K)):
+ print('Warning: numerical errors at iteration', cpt)
+ break
+ if cpt % 10 == 0:
+ err = np.linalg.norm(Kprev - K)
+ if log:
+ log_e['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 11)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt = cpt + 1
+ log_e['partial_w_dist'] = np.sum(M * K)
+ if log:
+ return K, log_e
+ else:
+ return K
+
+
+def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
+ numItermax=1000, tol=1e-7, log=False,
+ verbose=False):
+ r"""
+ Returns the partial Gromov-Wasserstein transport between (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot
+ \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma)
+
+ s.t.
+ \gamma\geq 0 \\
+ \gamma 1 \leq a\\
+ \gamma^T 1 \leq b\\
+ 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+
+ where :
+
+ - C1 is the metric cost matrix in the source space
+ - C2 is the metric cost matrix in the target space
+ - p and q are the sample weights
+ - L : quadratic loss function
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - m is the amount of mass to be transported
+
+ The formulation of the problem has been proposed in [12]_.
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ reg: float
+ entropic regularization parameter
+ m : float, optional
+ Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ G0 : ndarray, shape (ns, nt), optional
+ Initialisation of the transportation matrix
+ numItermax : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ log : bool, optional
+ return log if True
+ verbose : bool, optional
+ Print information along iterations
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> import scipy as sp
+ >>> a = np.array([0.25] * 4)
+ >>> b = np.array([0.25] * 4)
+ >>> x = np.array([1,2,100,200]).reshape((-1,1))
+ >>> y = np.array([3,2,98,199]).reshape((-1,1))
+ >>> C1 = sp.spatial.distance.cdist(x, x)
+ >>> C2 = sp.spatial.distance.cdist(y, y)
+ >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50), 2)
+ array([[0.12, 0.13, 0. , 0. ],
+ [0.13, 0.12, 0. , 0. ],
+ [0. , 0. , 0.25, 0. ],
+ [0. , 0. , 0. , 0.25]])
+ >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50,
+ ... m=0.25), 2)
+ array([[0.02, 0.03, 0. , 0.03],
+ [0.03, 0.03, 0. , 0.03],
+ [0. , 0. , 0.03, 0. ],
+ [0.02, 0.02, 0. , 0.03]])
+
+ Returns
+ -------
+ :math:`\gamma` : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ See Also
+ --------
+ ot.partial.partial_gromov_wasserstein: exact Partial Gromov-Wasserstein
+ """
+
+ if G0 is None:
+ G0 = np.outer(p, q)
+
+ if m is None:
+ m = np.min((np.sum(p), np.sum(q)))
+ elif m < 0:
+ raise ValueError("Problem infeasible. Parameter m should be greater"
+ " than 0.")
+ elif m > np.min((np.sum(p), np.sum(q))):
+ raise ValueError("Problem infeasible. Parameter m should be lower or"
+ " equal to min(|p|_1, |q|_1).")
+
+ cpt = 0
+ err = 1
+
+ loge = {'err': []}
+
+ while (err > tol and cpt < numItermax):
+ Gprev = G0
+ M_entr = gwgrad_partial(C1, C2, G0)
+ G0 = entropic_partial_wasserstein(p, q, M_entr, reg, m)
+ if cpt % 10 == 0: # to speed up the computations
+ err = np.linalg.norm(G0 - Gprev)
+ if log:
+ loge['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}|{:12s}'.format(
+ 'It.', 'Err', 'Loss') + '\n' + '-' * 31)
+ print('{:5d}|{:8e}|{:8e}'.format(cpt, err,
+ gwloss_partial(C1, C2, G0)))
+
+ cpt += 1
+
+ if log:
+ loge['partial_gw_dist'] = gwloss_partial(C1, C2, G0)
+ return G0, loge
+ else:
+ return G0
+
+
+def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
+ numItermax=1000, tol=1e-7, log=False,
+ verbose=False):
+ r"""
+ Returns the partial Gromov-Wasserstein discrepancy between (C1,p) and
+ (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot
+ \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma)
+
+ s.t.
+ \gamma\geq 0 \\
+ \gamma 1 \leq a\\
+ \gamma^T 1 \leq b\\
+ 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+
+ where :
+
+ - C1 is the metric cost matrix in the source space
+ - C2 is the metric cost matrix in the target space
+ - p and q are the sample weights
+ - L : quadratic loss function
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - m is the amount of mass to be transported
+
+ The formulation of the problem has been proposed in [12]_.
+
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ reg: float
+ entropic regularization parameter
+ m : float, optional
+ Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ G0 : ndarray, shape (ns, nt), optional
+ Initialisation of the transportation matrix
+ numItermax : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ log : bool, optional
+ return log if True
+ verbose : bool, optional
+ Print information along iterations
+
+
+ Returns
+ -------
+ partial_gw_dist : float
+ partial Gromov-Wasserstein discrepancy
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> import scipy as sp
+ >>> a = np.array([0.25] * 4)
+ >>> b = np.array([0.25] * 4)
+ >>> x = np.array([1,2,100,200]).reshape((-1,1))
+ >>> y = np.array([3,2,98,199]).reshape((-1,1))
+ >>> C1 = sp.spatial.distance.cdist(x, x)
+ >>> C2 = sp.spatial.distance.cdist(y, y)
+ >>> np.round(entropic_partial_gromov_wasserstein2(C1, C2, a, b, 50), 2)
+ 1.87
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+ """
+
+ partial_gw, log_gw = entropic_partial_gromov_wasserstein(C1, C2, p, q, reg,
+ m, G0, numItermax,
+ tol, True,
+ verbose)
+
+ log_gw['T'] = partial_gw
+
+ if log:
+ return log_gw['partial_gw_dist'], log_gw
+ else:
+ return log_gw['partial_gw_dist']
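
The reduction at the heart of `partial_wasserstein` (and of the EMD step inside `partial_gromov_wasserstein`) is worth spelling out: the partial problem becomes a balanced one once each side is extended with reservoir ("dummy") points that absorb the untransported mass. A minimal standalone sketch of that trick, assuming a single dummy point per side (the module uses `nb_dummies` of them); the helper name is illustrative:

    import numpy as np
    import ot

    def partial_emd_via_dummies(a, b, M, m):
        # append one dummy point on each side; it carries the leftover mass
        a_ext = np.append(a, np.sum(b) - m)
        b_ext = np.append(b, np.sum(a) - m)
        M_ext = np.zeros((len(a) + 1, len(b) + 1))
        M_ext[:len(a), :len(b)] = M
        M_ext[-1, -1] = np.max(M) * 1e5   # forbid dummy-to-dummy transport
        G = ot.emd(a_ext, b_ext, M_ext)   # balanced: both sides sum to |a|+|b|-m
        return G[:len(a), :len(b)]        # drop the dummy row and column

    a = b = ot.unif(4)
    M = ot.dist(np.arange(4.)[:, None], np.arange(4.)[:, None] + 0.5)
    print(np.sum(partial_emd_via_dummies(a, b, M, m=0.5)))  # ~0.5

This mirrors the a_extended/b_extended/M_extended construction used in partial_wasserstein above.
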
diff --git a/test/test_partial.py b/test/test_partial.py
new file mode 100755
index 0000000..fbcd3c2
--- /dev/null
+++ b/test/test_partial.py
@@ -0,0 +1,141 @@
+"""Tests for module partial """
+
+# Author:
+# Laetitia Chapel <laetitia.chapel@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import scipy as sp
+import ot
+
+
+def test_partial_wasserstein():
+
+ n_samples = 20 # nb samples (gaussian)
+ n_noise = 20 # nb of samples (noise)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+ M = ot.dist(xs, xt)
+
+ p = ot.unif(n_samples + n_noise)
+ q = ot.unif(n_samples + n_noise)
+
+ m = 0.5
+
+ w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
+ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
+ log=True)
+
+ # check constraints
+ np.testing.assert_equal(
+ w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
+
+ # check transported mass
+ np.testing.assert_allclose(
+ np.sum(w0), m, atol=1e-04)
+ np.testing.assert_allclose(
+ np.sum(w), m, atol=1e-04)
+
+ w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
+ w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
+
+ G = log0['T']
+
+ np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
+
+ # check constraints
+ np.testing.assert_equal(
+ G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_allclose(
+ np.sum(G), m, atol=1e-04)
+
+
+def test_partial_gromov_wasserstein():
+ n_samples = 20 # nb samples
+ n_noise = 10 # nb of samples (noise)
+
+ p = ot.unif(n_samples + n_noise)
+ q = ot.unif(n_samples + n_noise)
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ mu_t = np.array([0, 0, 0])
+ cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+ P = sp.linalg.sqrtm(cov_t)
+ xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+ xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+ xt2 = xs[::-1].copy()
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ C3 = ot.dist(xt2, xt2)
+
+ m = 2 / 3
+ res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C3, p, q, m=m,
+ log=True)
+ res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C3, p, q, 10,
+ m=m, log=True)
+ np.testing.assert_allclose(log0['partial_gw_dist'], 0, atol=1e-1,
+ rtol=1e-1)
+ np.testing.assert_allclose(log['partial_gw_dist'], 0, atol=1e-1,
+ rtol=1e-1)
+
+ C1 = sp.spatial.distance.cdist(xs, xs)
+ C2 = sp.spatial.distance.cdist(xt, xt)
+
+ m = 1
+ res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
+ log=True)
+ G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
+ np.testing.assert_allclose(G, res0, atol=1e-04)
+
+ res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=10)
+ np.testing.assert_allclose(G, res, atol=1e-02)
+
+ w0, log0 = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m,
+ log=True)
+ w0_val = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m,
+ log=False)
+ G = log0['T']
+ np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
+
+ m = 2 / 3
+ res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
+ log=True)
+ res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+ # check constraints
+ np.testing.assert_equal(
+ res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ res0.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_allclose(
+ np.sum(res0), m, atol=1e-04)
+
+ np.testing.assert_equal(
+ res.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ res.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_allclose(
+ np.sum(res), m, atol=1e-04)
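
For reference, the scaling loop inside entropic_partial_wasserstein boils down to three projections per iteration: cap the row marginals by a, cap the column marginals by b, and rescale the total mass to m. A NumPy-only sketch of that iteration, under the same assumptions as the solver (m <= min(|a|_1, |b|_1)); the helper name is illustrative:

    import numpy as np

    def entropic_partial_sketch(a, b, M, reg, m, n_iter=1000):
        K = np.exp(-M / reg)
        K *= m / K.sum()                                    # start at mass m
        for _ in range(n_iter):
            K = np.minimum(a / K.sum(1), 1.0)[:, None] * K  # rows <= a
            K = K * np.minimum(b / K.sum(0), 1.0)[None, :]  # columns <= b
            K *= m / K.sum()                                # mass back to m
        return K

    a = np.array([.1, .2])
    b = np.array([.1, .1])
    M = np.array([[0., 1.], [2., 3.]])
    print(entropic_partial_sketch(a, b, M, reg=1., m=0.1).sum())  # ~0.1
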