author: Rémi Flamary <remi.flamary@gmail.com> 2022-02-02 11:53:12 +0100
committer: GitHub <noreply@github.com> 2022-02-02 11:53:12 +0100
commit: a5e0f0d40d5046a6639924347ef97e2ac80ad0c9 (patch)
tree: dcd35e851ec2cc3f52eedbfa58fb6970664135c9 /ot
parent: 71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 (diff)
[MRG] Add weak OT solver (#341)
* add info in release file
* update tests
* pep8
* add weak OT example
* update plot in doc
* correct example with empirical sinkhorn
* better thumbnail
* comments from review
* update documentation
Diffstat (limited to 'ot')
-rw-r--r--   ot/__init__.py      5
-rw-r--r--   ot/gromov.py       16
-rw-r--r--   ot/lp/__init__.py   9
-rw-r--r--   ot/lp/cvx.py        1
-rw-r--r--   ot/utils.py        12
-rw-r--r--   ot/weak.py        124
6 files changed, 159 insertions, 8 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 1ea7403..7253318 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -36,6 +36,7 @@ from . import unbalanced
from . import partial
from . import backend
from . import regpath
+from . import weak
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -46,7 +47,7 @@ from .da import sinkhorn_lpl1_mm
from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
-
+from .weak import weak_optimal_transport
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -59,5 +60,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
- 'max_sliced_wasserstein_distance',
+ 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
diff --git a/ot/gromov.py b/ot/gromov.py
index 6544260..b7e7949 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
- :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
Note that when using backends, this loss function is differentiable wrt the
matrices and weights for quadratic loss using the gradients from [38]_.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
@@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
The algorithm used for solving the problem is conditional gradient as
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
Note that when using backends, this loss function is differentiable wrt the
matrices and weights for quadratic loss using the gradients from [38]_.
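As a minimal sketch of the backend note repeated in these docstrings (hypothetical random data; `ot.dist` and `ot.unif` are existing POT helpers), the solvers accept arrays from any backend but run their core loop through the C++ CPU solver:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(5, 2)            # source samples (hypothetical data)
    xt = rng.randn(6, 3)            # target samples
    C1 = ot.dist(xs, xs)            # intra-domain similarity matrices
    C2 = ot.dist(xt, xt)
    p = ot.unif(5)                  # uniform marginals
    q = ot.unif(6)

    # the call is identical for torch/jax inputs, but GPU arrays are
    # copied to CPU for the exact solver and copied back afterwards
    T = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')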
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 2ff7c1f..d9b6fa9 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -26,6 +26,8 @@ from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
+
+
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -220,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
format
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
@@ -358,7 +361,8 @@ def emd2(a, b, M, processes=1,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
@@ -622,3 +626,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X
+
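To illustrate the backend note on `emd`/`emd2` above (a small sketch with hypothetical data): numpy inputs stay in numpy, while tensors from other backends are converted for the C++ solver and the result is converted back.

    import numpy as np
    import ot

    a = ot.unif(3)                          # source histogram
    b = ot.unif(3)                          # target histogram
    x = np.arange(3, dtype=float).reshape(-1, 1)
    M = ot.dist(x, x)                       # squared Euclidean cost matrix
    G = ot.emd(a, b, M)                     # numpy in, numpy out

    # with torch installed the same call accepts tensors; they are moved
    # to CPU/numpy internally, which is the copy overhead mentioned above
    # G = ot.emd(torch.from_numpy(a), torch.from_numpy(b), torch.from_numpy(M))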
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 869d450..fbf3c0e 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -11,7 +11,6 @@ import numpy as np
import scipy as sp
import scipy.sparse as sps
-
try:
import cvxopt
from cvxopt import solvers, matrix, spmatrix
diff --git a/ot/utils.py b/ot/utils.py
index e6c93c8..725ca00 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -116,7 +116,7 @@ def proj_simplex(v, z=1):
return w
-def unif(n):
+def unif(n, type_as=None):
r"""
Return a uniform histogram of length `n` (simplex).
@@ -124,13 +124,19 @@ def unif(n):
----------
n : int
number of bins in the histogram
+ type_as : array_like
+ array of the same type as the expected output (numpy/pytorch/jax)
Returns
-------
- h : np.array (`n`,)
+ h : array_like (`n`,)
histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
"""
- return np.ones((n,)) / n
+ if type_as is None:
+ return np.ones((n,)) / n
+ else:
+ nx = get_backend(type_as)
+ return nx.ones((n,)) / n
def clean_zeros(a, b, M):
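The new `type_as` argument can be used as follows (a minimal sketch; the commented lines assume torch and the POT torch backend are available):

    import numpy as np
    import ot

    h = ot.unif(4)              # default: numpy array [0.25, 0.25, 0.25, 0.25]

    x = np.zeros((3, 2))
    h2 = ot.unif(4, type_as=x)  # array from the same backend as x (numpy here)

    # with torch installed, passing a torch tensor returns a torch tensor
    # import torch
    # ht = ot.unif(4, type_as=torch.zeros(1))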
diff --git a/ot/weak.py b/ot/weak.py
new file mode 100644
index 0000000..f7d5b23
--- /dev/null
+++ b/ot/weak.py
@@ -0,0 +1,124 @@
+"""
+Weak optimal transport solvers
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .optim import cg
+import numpy as np
+
+__all__ = ['weak_optimal_transport']
+
+
+def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
+ r"""Solves the weak optimal transport problem between two empirical distributions
+
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \|X_a - \mathrm{diag}(1/\mathbf{a}) \gamma X_b\|_F^2
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
+
+ where :
+
+ - :math:`X_a` and :math:`X_b` are the sample matrices.
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ Uses the conditional gradient algorithm to solve the problem proposed
+ in :ref:`[39] <references-weak>`.
+
+ Parameters
+ ----------
+ Xa : (ns,d) array-like, float
+ Source samples
+ Xb : (nt,d) array-like, float
+ Target samples
+ a : (ns,) array-like, float
+ Source histogram (uniform weights if None)
+ b : (nt,) array-like, float
+ Target histogram (uniform weights if None)
+ G0 : (ns, nt) array-like, float, optional
+ Initial transport plan (independent coupling of a and b if None)
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
+
+
+ .. _references-weak:
+ References
+ ----------
+ .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
+ Kantorovich duality for general transport costs and applications.
+ Journal of Functional Analysis, 273(11), 3327-3405.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+ """
+
+ nx = get_backend(Xa, Xb)
+
+ Xa2 = nx.to_numpy(Xa)
+ Xb2 = nx.to_numpy(Xb)
+
+ if a is None:
+ a2 = np.ones((Xa.shape[0])) / Xa.shape[0]
+ else:
+ a2 = nx.to_numpy(a)
+ if b is None:
+ b2 = np.ones((Xb.shape[0])) / Xb.shape[0]
+ else:
+ b2 = nx.to_numpy(b)
+
+ # init uniform
+ if G0 is None:
+ T0 = a2[:, None] * b2[None, :]
+ else:
+ T0 = nx.to_numpy(G0)
+
+ # weak OT loss
+ def f(T):
+ return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1))
+
+ # weak OT gradient
+ def df(T):
+ return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T)
+
+ # solve with conditional gradient and return solution
+ if log:
+ res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs)
+ log['u'] = nx.from_numpy(log['u'], type_as=Xa)
+ log['v'] = nx.from_numpy(log['v'], type_as=Xb)
+ return nx.from_numpy(res, type_as=Xa), log
+ else:
+ return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa)
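Finally, a minimal end-to-end sketch of the new solver (hypothetical random data; uniform weights are used when `a` and `b` are omitted, matching the code above):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    Xa = rng.randn(50, 2)           # source samples (hypothetical data)
    Xb = rng.randn(40, 2) + 2.0     # target samples, shifted

    # uniform weights are used since a and b are not given
    T = ot.weak_optimal_transport(Xa, Xb)

    # T is a (50, 40) coupling: rows sum to 1/50, columns sum to 1/40
    print(T.shape, T.sum())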