summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorNicolas Courty <ncourty@irisa.fr>2021-11-02 14:19:57 +0100
committerGitHub <noreply@github.com>2021-11-02 14:19:57 +0100
commit6775a527f9d3c801f8cdd805d8f205b6a75551b9 (patch)
treec0ed5a7c297b4003688fec52d46f918ea0086a7d /ot/lp
parenta335324d008e8982be61d7ace937815a2bfa98f9 (diff)
[MRG] Sliced and 1D Wasserstein distances : backend versions (#256)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend * new backend functions for sliced * small indent pb * Optimized backendversion of sliced W * error in sliced W * after master merge * error sliced * error sliced * pep8 * test_sliced pep8 * doctest + precision for sliced * doctest * type win test_backend gather * type win test_backend gather * Update sliced.py change argument of padding pad_width * Update backend.py update redefinition * Update backend.py pep8 * Update backend.py pep 8 again.... * pep8 * build docs * emd2_1D example * refectoring emd_1d and variants * remove unused previous wasserstein_1d * pep8 * upate example * move stuff * tesys should work + implemù random backend * test random generayor functions * correction * better random generation * update sliced * update sliced * proper tests sliced * max sliced * chae file nam * add stuff * example sliced flow and barycenter * correct typo + update readme * exemple sliced flow done * pep8 * solver1d works * pep8 Co-authored-by: Rémi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/__init__.py367
-rw-r--r--ot/lp/solver_1d.py367
2 files changed, 401 insertions, 333 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 4e95ccf..2c18a88 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -13,20 +13,23 @@ import multiprocessing
import sys
import numpy as np
-from scipy.sparse import coo_matrix
import warnings
from . import cvx
from .cvx import barycenter
+
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
+from .solver_1d import emd_1d, emd2_1d, wasserstein_1d
+
from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
-__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
+
def check_number_threads(numThreads):
"""Checks whether or not the requested number of threads has a valid value.
@@ -115,10 +118,10 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
.. warning::
This function is necessary because the C++ solver in emd_c
- discards all samples in the distributions with
- zeros weights. This means that while the primal variable (transport
+ discards all samples in the distributions with
+ zeros weights. This means that while the primal variable (transport
matrix) is exact, the solver only returns feasible dual potentials
- on the samples with weights different from zero.
+ on the samples with weights different from zero.
First we compute the constraints violations:
@@ -215,26 +218,26 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
format
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) array-like, float
+ a : (ns,) array-like, float
Source histogram (uniform weight if empty list)
- b : (nt,) array-like, float
- Target histogram (uniform weight if empty list)
- M : (ns,nt) array-like, float
- Loss matrix (c-order array in numpy with type float64)
- numItermax : int, optional (default=100000)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list)
+ M : (ns,nt) array-like, float
+ Loss matrix (c-order array in numpy with type float64)
+ numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
- log: bool, optional (default=False)
- If True, returns a dictionary containing the cost and dual variables.
- Otherwise returns only the optimal transportation matrix.
+ algorithm if it has not converged.
+ log: bool, optional (default=False)
+ If True, returns a dictionary containing the cost and dual variables.
+ Otherwise returns only the optimal transportation matrix.
center_dual: boolean, optional (default=True)
- If True, centers the dual potential using function
+ If True, centers the dual potential using function
:ref:`center_ot_dual`.
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
@@ -242,9 +245,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
Returns
-------
- gamma: array-like, shape (ns, nt)
+ gamma: array-like, shape (ns, nt)
Optimal transportation matrix for the given
- parameters
+ parameters
log: dict, optional
If input log is true, a dictionary containing the
cost and dual variables and exit status
@@ -277,10 +280,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
regularized OT"""
# convert to numpy if list
- a, b, M = list_to_array(a, b, M)
+ a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
- nx = get_backend(M0, a0, b0)
+ nx = get_backend(M0, a0, b0)
# convert to numpy
M = nx.to_numpy(M)
@@ -302,9 +305,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
"Dimension mismatch, check dimensions of M with a and b"
# ensure that same mass
- np.testing.assert_almost_equal(a.sum(0),
- b.sum(0), err_msg='a and b vector must have the same sum')
- b=b*a.sum()/b.sum()
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0), err_msg='a and b vector must have the same sum')
+ b = b * a.sum() / b.sum()
asel = a != 0
bsel = b != 0
@@ -415,10 +418,10 @@ def emd2(a, b, M, processes=1,
ot.bregman.sinkhorn : Entropic regularized OT
ot.optim.cg : General regularized OT"""
- a, b, M = list_to_array(a, b, M)
+ a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
- nx = get_backend(M0, a0, b0)
+ nx = get_backend(M0, a0, b0)
# convert to numpy
M = nx.to_numpy(M)
@@ -427,7 +430,7 @@ def emd2(a, b, M, processes=1,
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64, order= 'C')
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -463,8 +466,8 @@ def emd2(a, b, M, processes=1,
log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0,b0, M0), (log['u'], log['v'], G))
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
def f(b):
@@ -479,8 +482,8 @@ def emd2(a, b, M, processes=1,
G = nx.from_numpy(G, type_as=M0)
cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0,b0, M0), (nx.from_numpy(u, type_as=a0),
- nx.from_numpy(v, type_as=b0),G))
+ (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
+ nx.from_numpy(v, type_as=b0), G))
check_result(result_code)
return cost
@@ -603,305 +606,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X
-
-
-def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- log=False):
- r"""Solves the Earth Movers distance problem between 1d measures and returns
- the OT matrix
-
-
- .. math::
- \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
- where :
-
- - d is the metric
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
- Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
- p: float, optional (default=1.0)
- The p-norm to apply for if metric='minkowski'
- dense: boolean, optional (default=True)
- If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
- Otherwise returns a sparse representation using scipy's `coo_matrix`
- format. Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
- are used.
- log: boolean, optional (default=False)
- If True, returns a dictionary containing the cost.
- Otherwise returns only the optimal transportation matrix.
-
- Returns
- -------
- gamma: (ns, nt) ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is True, a dictionary containing the cost
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function emd_1d accepts lists and
- performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.emd_1d(x_a, x_b, a, b)
- array([[0. , 0.5],
- [0.5, 0. ]])
- >>> ot.emd_1d(x_a, x_b)
- array([[0. , 0.5],
- [0.5, 0. ]])
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd : EMD for multidimensional distributions
- ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
- transportation matrix)
- """
- a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
- nx = get_backend(x_a, x_b)
-
- assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
- "emd_1d should only be used with monodimensional data"
- assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
- "emd_1d should only be used with monodimensional data"
-
- # if empty array given then use uniform distributions
- if a is None or a.ndim == 0 or len(a) == 0:
- a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
- if b is None or b.ndim == 0 or len(b) == 0:
- b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
-
- # ensure that same mass
- np.testing.assert_almost_equal(
- nx.sum(a, axis=0),
- nx.sum(b, axis=0),
- err_msg='a and b vector must have the same sum'
- )
- b = b * nx.sum(a) / nx.sum(b)
-
- x_a_1d = nx.reshape(x_a, (-1,))
- x_b_1d = nx.reshape(x_b, (-1,))
- perm_a = nx.argsort(x_a_1d)
- perm_b = nx.argsort(x_b_1d)
-
- G_sorted, indices, cost = emd_1d_sorted(
- nx.to_numpy(a[perm_a]),
- nx.to_numpy(b[perm_b]),
- nx.to_numpy(x_a_1d[perm_a]),
- nx.to_numpy(x_b_1d[perm_b]),
- metric=metric, p=p
- )
-
- G = nx.coo_matrix(
- G_sorted,
- perm_a[indices[:, 0]],
- perm_b[indices[:, 1]],
- shape=(a.shape[0], b.shape[0]),
- type_as=x_a
- )
- if dense:
- G = nx.todense(G)
- elif str(nx) == "jax":
- warnings.warn("JAX does not support sparse matrices, converting to dense")
- if log:
- log = {'cost': cost}
- return G, log
- return G
-
-
-def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- log=False):
- r"""Solves the Earth Movers distance problem between 1d measures and returns
- the loss
-
-
- .. math::
- \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
- where :
-
- - d is the metric
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
- Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
- are used.
- p: float, optional (default=1.0)
- The p-norm to apply for if metric='minkowski'
- dense: boolean, optional (default=True)
- If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
- Otherwise returns a sparse representation using scipy's `coo_matrix`
- format. Only used if log is set to True. Due to implementation details,
- this function runs faster when dense is set to False.
- log: boolean, optional (default=False)
- If True, returns a dictionary containing the transportation matrix.
- Otherwise returns only the loss.
-
- Returns
- -------
- loss: float
- Cost associated to the optimal transportation
- log: dict
- If input log is True, a dictionary containing the Optimal transportation
- matrix for the given parameters
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function emd2_1d accepts lists and
- performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.emd2_1d(x_a, x_b, a, b)
- 0.5
- >>> ot.emd2_1d(x_a, x_b)
- 0.5
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd2 : EMD for multidimensional distributions
- ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
- instead of the cost)
- """
- # If we do not return G (log==False), then we should not to cast it to dense
- # (useless overhead)
- G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
- dense=dense and log, log=True)
- cost = log_emd['cost']
- if log:
- log_emd = {'G': G}
- return cost, log_emd
- return cost
-
-
-def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.):
- r"""Solves the p-Wasserstein distance problem between 1d measures and returns
- the distance
-
- .. math::
- \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p}
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
-
- where :
-
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- p: float, optional (default=1.0)
- The order of the p-Wasserstein distance to be computed
-
- Returns
- -------
- dist: float
- p-Wasserstein distance
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function wasserstein_1d accepts
- lists and performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.wasserstein_1d(x_a, x_b, a, b)
- 0.5
- >>> ot.wasserstein_1d(x_a, x_b)
- 0.5
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd_1d : EMD for 1d distributions
- """
- cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
- dense=False, log=False)
- return np.power(cost_emd, 1. / p)
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
new file mode 100644
index 0000000..42554aa
--- /dev/null
+++ b/ot/lp/solver_1d.py
@@ -0,0 +1,367 @@
+# -*- coding: utf-8 -*-
+"""
+Exact solvers for the 1D Wasserstein distance using cvxopt
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Author: Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import warnings
+
+from .emd_wrap import emd_1d_sorted
+from ..backend import get_backend
+from ..utils import list_to_array
+
+
+def quantile_function(qs, cws, xs):
+ r""" Computes the quantile function of an empirical distribution
+
+ Parameters
+ ----------
+ qs: array-like, shape (n,)
+ Quantiles at which the quantile function is evaluated
+ cws: array-like, shape (m, ...)
+ cumulative weights of the 1D empirical distribution, if batched, must be similar to xs
+ xs: array-like, shape (n, ...)
+ locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions
+
+ Returns
+ -------
+ q: array-like, shape (..., n)
+ The quantiles of the distribution
+ """
+ nx = get_backend(qs, cws)
+ n = xs.shape[0]
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ cws = cws.T.contiguous()
+ qs = qs.T.contiguous()
+ else:
+ cws = cws.T
+ qs = qs.T
+ idx = nx.searchsorted(cws, qs).T
+ return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0)
+
+
+def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
+ r"""
+ Computes the 1 dimensional OT loss [15] between two (batched) empirical
+ distributions
+
+ .. math:
+ OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq
+
+ It is formally the p-Wasserstein distance raised to the power p.
+ We do so in a vectorized way by first building the individual quantile functions then integrating them.
+
+ This function should be preferred to `emd_1d` whenever the backend is
+ different to numpy, and when gradients over
+ either sample positions or weights are required.
+
+ Parameters
+ ----------
+ u_values: array-like, shape (n, ...)
+ locations of the first empirical distribution
+ v_values: array-like, shape (m, ...)
+ locations of the second empirical distribution
+ u_weights: array-like, shape (n, ...), optional
+ weights of the first empirical distribution, if None then uniform weights are used
+ v_weights: array-like, shape (m, ...), optional
+ weights of the second empirical distribution, if None then uniform weights are used
+ p: int, optional
+ order of the ground metric used, should be at least 1 (see [2, Chap. 2], default is 1
+ require_sort: bool, optional
+ sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to
+ the function, default is True
+
+ Returns
+ -------
+ cost: float/array-like, shape (...)
+ the batched EMD
+
+ References
+ ----------
+ .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.
+
+ """
+
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ u_cumweights = nx.cumsum(u_weights, 0)
+ v_cumweights = nx.cumsum(v_weights, 0)
+
+ qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0)
+ u_quantiles = quantile_function(qs, u_cumweights, u_values)
+ v_quantiles = quantile_function(qs, v_cumweights, v_values)
+ qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
+ delta = qs[1:, ...] - qs[:-1, ...]
+ diff_quantiles = nx.abs(u_quantiles - v_quantiles)
+
+ if p == 1:
+ return nx.sum(delta * nx.abs(diff_quantiles), axis=0)
+ return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
+
+
+def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the OT matrix
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+ x_b : (nt,) or (ns, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost.
+ Otherwise returns only the optimal transportation matrix.
+
+ Returns
+ -------
+ gamma: (ns, nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is True, a dictionary containing the cost
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd_1d(x_a, x_b, a, b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+ >>> ot.emd_1d(x_a, x_b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd : EMD for multidimensional distributions
+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
+ transportation matrix)
+ """
+ a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
+ nx = get_backend(x_a, x_b)
+
+ assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+ assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+
+ # if empty array given then use uniform distributions
+ if a is None or a.ndim == 0 or len(a) == 0:
+ a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
+ if b is None or b.ndim == 0 or len(b) == 0:
+ b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
+
+ # ensure that same mass
+ np.testing.assert_almost_equal(
+ nx.sum(a, axis=0),
+ nx.sum(b, axis=0),
+ err_msg='a and b vector must have the same sum'
+ )
+ b = b * nx.sum(a) / nx.sum(b)
+
+ x_a_1d = nx.reshape(x_a, (-1,))
+ x_b_1d = nx.reshape(x_b, (-1,))
+ perm_a = nx.argsort(x_a_1d)
+ perm_b = nx.argsort(x_b_1d)
+
+ G_sorted, indices, cost = emd_1d_sorted(
+ nx.to_numpy(a[perm_a]),
+ nx.to_numpy(b[perm_b]),
+ nx.to_numpy(x_a_1d[perm_a]),
+ nx.to_numpy(x_b_1d[perm_b]),
+ metric=metric, p=p
+ )
+
+ G = nx.coo_matrix(
+ G_sorted,
+ perm_a[indices[:, 0]],
+ perm_b[indices[:, 1]],
+ shape=(a.shape[0], b.shape[0]),
+ type_as=x_a
+ )
+ if dense:
+ G = nx.todense(G)
+ elif str(nx) == "jax":
+ warnings.warn("JAX does not support sparse matrices, converting to dense")
+ if log:
+ log = {'cost': cost}
+ return G, log
+ return G
+
+
+def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the loss
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+ x_b : (nt,) or (ns, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Only used if log is set to True. Due to implementation details,
+ this function runs faster when dense is set to False.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the transportation matrix.
+ Otherwise returns only the loss.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict
+ If input log is True, a dictionary containing the Optimal transportation
+ matrix for the given parameters
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd2_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd2_1d(x_a, x_b, a, b)
+ 0.5
+ >>> ot.emd2_1d(x_a, x_b)
+ 0.5
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd2 : EMD for multidimensional distributions
+ ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
+ instead of the cost)
+ """
+ # If we do not return G (log==False), then we should not to cast it to dense
+ # (useless overhead)
+ G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
+ dense=dense and log, log=True)
+ cost = log_emd['cost']
+ if log:
+ log_emd = {'G': G}
+ return cost, log_emd
+ return cost