summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/EMD.h5
-rw-r--r--ot/lp/EMD_wrapper.cpp40
-rw-r--r--ot/lp/__init__.py161
-rw-r--r--ot/lp/cvx.py2
-rw-r--r--ot/lp/emd_wrap.pyx9
-rw-r--r--ot/lp/network_simplex_simple.h12
-rw-r--r--ot/lp/network_simplex_simple_omp.h20
-rw-r--r--ot/lp/solver_1d.py629
8 files changed, 809 insertions, 69 deletions
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index 8a1f9ac..b56f060 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -18,6 +18,7 @@
#include <iostream>
#include <vector>
+#include <cstdint>
typedef unsigned int node_id_type;
@@ -28,8 +29,8 @@ enum ProblemType {
MAX_ITER_REACHED
};
-int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
-int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
+int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
+int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 2bdc172..4aa5a6e 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -20,11 +20,11 @@
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
- double* alpha, double* beta, double *cost, int maxIter) {
+ double* alpha, double* beta, double *cost, uint64_t maxIter) {
// beware M and C are stored in row major C style!!!
using namespace lemon;
- int n, m, cur;
+ uint64_t n, m, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
@@ -51,15 +51,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Define the graph
- std::vector<int> indI(n), indJ(m);
+ std::vector<uint64_t> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter);
// Set supply and demand, don't account for 0 values (faster)
cur=0;
- for (int i=0; i<n1; i++) {
+ for (uint64_t i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
@@ -70,7 +70,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Demand is actually negative supply...
cur=0;
- for (int i=0; i<n2; i++) {
+ for (uint64_t i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
@@ -79,12 +79,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
}
- net.supplyMap(&weights1[0], n, &weights2[0], m);
+ net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
// Set the cost of each edge
int64_t idarc = 0;
- for (int i=0; i<n; i++) {
- for (int j=0; j<m; j++) {
+ for (uint64_t i=0; i<n; i++) {
+ for (uint64_t j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(idarc), val);
++idarc;
@@ -95,7 +95,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm
int ret=net.run();
- int i, j;
+ uint64_t i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
@@ -122,11 +122,11 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
- double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
+ double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) {
// beware M and C are stored in row major C style!!!
using namespace lemon_omp;
- int n, m, cur;
+ uint64_t n, m, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
@@ -153,15 +153,15 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Define the graph
- std::vector<int> indI(n), indJ(m);
+ std::vector<uint64_t> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter, numThreads);
// Set supply and demand, don't account for 0 values (faster)
cur=0;
- for (int i=0; i<n1; i++) {
+ for (uint64_t i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
@@ -172,7 +172,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Demand is actually negative supply...
cur=0;
- for (int i=0; i<n2; i++) {
+ for (uint64_t i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
@@ -181,12 +181,12 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
}
- net.supplyMap(&weights1[0], n, &weights2[0], m);
+ net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
// Set the cost of each edge
int64_t idarc = 0;
- for (int i=0; i<n; i++) {
- for (int j=0; j<m; j++) {
+ for (uint64_t i=0; i<n; i++) {
+ for (uint64_t j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(idarc), val);
++idarc;
@@ -197,7 +197,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm
int ret=net.run();
- int i, j;
+ uint64_t i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 390c32d..2ff02ab 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Solvers for the original linear program OT problem
+Solvers for the original linear program OT problem.
"""
@@ -20,16 +20,17 @@ from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from .solver_1d import emd_1d, emd2_1d, wasserstein_1d
+from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d,
+ binary_search_circle, wasserstein_circle,
+ semidiscrete_wasserstein2_unif_circle)
from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
-
-
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d']
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter',
+ 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle']
def check_number_threads(numThreads):
@@ -232,6 +233,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
@@ -391,6 +394,8 @@ def emd2(a, b, M, processes=1,
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
@@ -483,6 +488,11 @@ def emd2(a, b, M, processes=1,
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum')
+ b = b * a.sum(0) / b.sum(0,keepdims=True)
+
asel = a != 0
numThreads = check_number_threads(numThreads)
@@ -517,8 +527,8 @@ def emd2(a, b, M, processes=1,
log['warning'] = result_code_string
log['result_code'] = result_code
cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
- (a0, b0, M0), (log['u'] - nx.mean(log['u']),
- log['v'] - nx.mean(log['v']), G))
+ (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+ log['v'] - nx.mean(log['v']), G))
return [cost, log]
else:
def f(b):
@@ -572,18 +582,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
where :
- :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
- - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i`
- - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations
+ - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
+ - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
- :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
- This problem is considered in :ref:`[1] <references-free-support-barycenter>` (Algorithm 2).
+ This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
There are two differences with the following codes:
- we do not optimize over the weights
- we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
- :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
implementation of the fixed-point algorithm of
- :ref:`[2] <references-free-support-barycenter>` proposed in the continuous setting.
+ :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
Parameters
----------
@@ -623,13 +633,13 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
.. _references-free-support-barycenter:
References
----------
- .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+ .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
- .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+ .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
"""
- nx = get_backend(*measures_locations,*measures_weights,X_init)
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
iter_count = 0
@@ -637,9 +647,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
- b = nx.ones((k,),type_as=X_init) / k
+ b = nx.ones((k,), type_as=X_init) / k
if weights is None:
- weights = nx.ones((N,),type_as=X_init) / N
+ weights = nx.ones((N,), type_as=X_init) / N
X = X_init
@@ -650,15 +660,14 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
while (displacement_square_norm > stopThr and iter_count < numItermax):
- T_sum = nx.zeros((k, d),type_as=X_init)
-
+ T_sum = nx.zeros((k, d), type_as=X_init)
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
- T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i)
+ T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
- displacement_square_norm = nx.sum((T_sum - X)**2)
+ displacement_square_norm = nx.sum((T_sum - X) ** 2)
if log:
displacement_square_norms.append(displacement_square_norm)
@@ -675,3 +684,111 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
else:
return X
+
+def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None,
+ numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0):
+ r"""
+ Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with
+ a fixed amount of points of uniform weights) whose respective projections fit the input measures.
+ More formally:
+
+ .. math::
+ \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma)
+
+ where :
+
+ - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d`
+ - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter
+ - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}`
+ - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex)
+ - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations
+ - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex)
+ - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}`
+
+ As show by :ref:`[42] <references-generalized-free-support-barycenter>`,
+ this problem can be re-written as a Wasserstein Barycenter problem,
+ which we solve using the free support method :ref:`[20] <references-generalized-free-support-barycenter>`
+ (Algorithm 2).
+
+ Parameters
+ ----------
+ X_list : list of p (k_i,d_i) array-like
+ Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ a_list : list of p (k_i,) array-like
+ Measure weights: each element is a vector (k_i) on the simplex
+ P_list : list of p (d_i,d) array-like
+ Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}`
+ n_samples_bary : int
+ Number of barycenter points
+ Y_init : (n_samples_bary,d) array-like
+ Initialization of the support locations (on `k` atoms) of the barycenter
+ b : (n_samples_bary,) array-like
+ Initialization of the weights of the barycenter measure (on the simplex)
+ weights : (p,) array-like
+ Initialization of the coefficients of the barycenter (on the simplex)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
+ eps: Stability coefficient for the change of variable matrix inversion
+ If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix
+ inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense)
+
+
+ Returns
+ -------
+ Y : (n_samples_bary,d) array-like
+ Support locations (on n_samples_bary atoms) of the barycenter
+
+
+ .. _references-generalized-free-support-barycenter:
+ References
+ ----------
+ .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021.
+
+ """
+ nx = get_backend(*X_list, *a_list, *P_list)
+ d = P_list[0].shape[1]
+ p = len(X_list)
+
+ if weights is None:
+ weights = nx.ones(p, type_as=X_list[0]) / p
+
+ # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB)
+ A = eps * nx.eye(d, type_as=X_list[0]) # if eps nonzero: will force the invertibility of A
+ for (P_i, lambda_i) in zip(P_list, weights):
+ A = A + lambda_i * P_i.T @ P_i
+ B = nx.inv(nx.sqrtm(A))
+
+ Z_list = [x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)] # change of variables -> (WB) problem on Z
+
+ if Y_init is None:
+ Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0])
+
+ if b is None:
+ b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised
+
+ out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads)
+
+ if log: # unpack
+ Y, log_dict = out
+ else:
+ Y = out
+ log_dict = None
+ Y = Y @ B.T # return to the Generalised WB formulation
+
+ if log:
+ return Y, log_dict
+ else:
+ return Y
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index fbf3c0e..361ad0f 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -80,7 +80,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
if weights is None:
weights = np.ones(A.shape[1]) / A.shape[1]
else:
- assert(len(weights) == A.shape[1])
+ assert len(weights) == A.shape[1]
n_distributions = A.shape[1]
n = A.shape[0]
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 42e08f4..e5cec89 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -14,13 +14,14 @@ from ..utils import dist
cimport cython
cimport libc.math as math
+from libc.stdint cimport uint64_t
import warnings
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
- int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
+ int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -39,7 +40,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -75,7 +76,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
target histogram
M : (ns,nt) numpy.ndarray, float64
loss matrix
- max_iter : int
+ max_iter : uint64_t
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 3b46b9b..9612a8a 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -233,7 +233,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -242,7 +242,7 @@ namespace lemon {
{
// Reset data structures
reset();
- max_iter=maxiters;
+ max_iter = maxiters;
}
/// The type of the flow amounts, capacity bounds and supply values
@@ -293,7 +293,7 @@ namespace lemon {
private:
- size_t max_iter;
+ uint64_t max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
typedef std::vector<int> IntVector;
@@ -1427,14 +1427,12 @@ namespace lemon {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- size_t iter_number=0;
+ uint64_t iter_number = 0;
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
if(max_iter > 0 && ++iter_number>=max_iter&&max_iter>0){
- char errMess[1000];
- sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
- std::cerr << errMess;
+ // max iterations hit
retVal = MAX_ITER_REACHED;
break;
}
diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h
index 87e4c05..890b7ab 100644
--- a/ot/lp/network_simplex_simple_omp.h
+++ b/ot/lp/network_simplex_simple_omp.h
@@ -41,8 +41,8 @@
#undef EPSILON
#undef _EPSILON
#undef MAX_DEBUG_ITER
-#define EPSILON std::numeric_limits<Cost>::epsilon()*10
-#define _EPSILON 1e-8
+#define EPSILON std::numeric_limits<Cost>::epsilon()
+#define _EPSILON 1e-14
#define MAX_DEBUG_ITER 100000
/// \ingroup min_cost_flow_algs
@@ -67,7 +67,7 @@
//#include "core.h"
//#include "lmath.h"
-#ifdef OMP
+#ifdef _OPENMP
#include <omp.h>
#endif
#include <cmath>
@@ -244,7 +244,7 @@ namespace lemon_omp {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters = 0, int numThreads=-1) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -254,7 +254,7 @@ namespace lemon_omp {
// Reset data structures
reset();
max_iter = maxiters;
-#ifdef OMP
+#ifdef _OPENMP
if (max_threads < 0) {
max_threads = omp_get_max_threads();
}
@@ -317,7 +317,7 @@ namespace lemon_omp {
private:
- size_t max_iter;
+ uint64_t max_iter;
int num_threads;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
@@ -513,7 +513,7 @@ namespace lemon_omp {
int j;
#pragma omp parallel
{
-#ifdef OMP
+#ifdef _OPENMP
int t = omp_get_thread_num();
#else
int t = 0;
@@ -1563,7 +1563,7 @@ namespace lemon_omp {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- size_t iter_number = 0;
+ uint64_t iter_number = 0;
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) {
@@ -1610,9 +1610,7 @@ namespace lemon_omp {
} else {
- char errMess[1000];
- sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
- std::cerr << errMess;
+ // max iters
retVal = MAX_ITER_REACHED;
break;
}
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 43763a9..bcfc920 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -53,7 +53,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
distributions
.. math:
- OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq
+ OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq
It is formally the p-Wasserstein distance raised to the power p.
We do so in a vectorized way by first building the individual quantile functions then integrating them.
@@ -129,7 +129,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
diff_quantiles = nx.abs(u_quantiles - v_quantiles)
if p == 1:
- return nx.sum(delta * nx.abs(diff_quantiles), axis=0)
+ return nx.sum(delta * diff_quantiles, axis=0)
return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
@@ -365,3 +365,628 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
log_emd = {'G': G}
return cost, log_emd
return cost
+
+
+def roll_cols(M, shifts):
+ r"""
+ Utils functions which allow to shift the order of each row of a 2d matrix
+
+ Parameters
+ ----------
+ M : (nr, nc) ndarray
+ Matrix to shift
+ shifts: int or (nr,) ndarray
+
+ Returns
+ -------
+ Shifted array
+
+ Examples
+ --------
+ >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
+ >>> roll_cols(M, 2)
+ array([[2, 3, 1],
+ [5, 6, 4],
+ [8, 9, 7]])
+ >>> roll_cols(M, np.array([[1],[2],[1]]))
+ array([[3, 1, 2],
+ [5, 6, 4],
+ [9, 7, 8]])
+
+ References
+ ----------
+ https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
+ """
+ nx = get_backend(M)
+
+ n_rows, n_cols = M.shape
+
+ arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1))
+ arange2 = (arange1 - shifts) % n_cols
+
+ return nx.take_along_axis(M, arange2, 1)
+
+
+def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
+ r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ dCp: array-like, shape (n_batch, 1)
+ The batched right derivative
+ dCm: array-like, shape (n_batch, 1)
+ The batched left derivative
+
+ References
+ ---------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ n = u_values.shape[-1]
+ m_batch, m = v_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ # quantiles of F_u evaluated in F_v^\theta
+ u_index = nx.searchsorted(u_cdf, v_cdf_theta)
+ u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)
+
+ # Deal with 1
+ u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1)
+ u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ u_cdfm = u_cdfm.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right")
+ u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1)
+
+ dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1)
+
+ dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1)
+
+ return dCp.reshape(-1, 1), dCm.reshape(-1, 1)
+
+
+def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p):
+ r""" Computes the the cost (Equation (6.2) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ ot_cost: array-like, shape (n_batch,)
+ OT cost evaluated at theta
+
+ References
+ ---------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ m_batch, m = v_values.shape
+ n_batch, n = u_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ # Put negative values at the end
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ # Compute absciss
+ cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1)
+ cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)])
+
+ delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1]
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+ cdf_axis = cdf_axis.contiguous()
+
+ # Compute icdf
+ u_index = nx.searchsorted(u_cdf, cdf_axis)
+ u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1)
+
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+ v_index = nx.searchsorted(v_cdf_theta, cdf_axis)
+ v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1)
+
+ if p == 1:
+ ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1)
+ else:
+ ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1)
+
+ return ot_cost
+
+
+def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True,
+ log=False):
+ r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ where:
+
+ - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC
+ Lp : int, optional
+ Upper bound dC
+ tm: float, optional
+ Lower bound theta
+ tp: float, optional
+ Upper bound theta
+ eps: float, optional
+ Stopping condition
+ require_sort: bool, optional
+ If True, sort the values.
+ log: bool, optional
+ If True, returns also the optimal theta
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict, optional
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> binary_search_circle(u.T, v.T, p=1)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ u_cdf = nx.cumsum(u_weights, 0).T
+ v_cdf = nx.cumsum(v_weights, 0).T
+
+ u_values = u_values.T
+ v_values = v_values.T
+
+ L = max(Lm, Lp)
+
+ tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tm = nx.tile(tm, (1, m))
+ tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tp = nx.tile(tp, (1, m))
+ tc = (tm + tp) / 2
+
+ done = nx.zeros((u_values.shape[0], m))
+
+ cpt = 0
+ while nx.any(1 - done):
+ cpt += 1
+
+ dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+ done = ((dCp * dCm) <= 0) * 1
+
+ mask = ((tp - tm) < eps / L) * (1 - done)
+
+ if nx.any(mask):
+ # can probably be improved by computing only relevant values
+ dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p)
+ dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p)
+ Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+ Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+
+ mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
+ tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0]
+ done[nx.prod(mask, axis=-1) > 0] = 1
+ elif nx.any(1 - done):
+ tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0]
+ tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0]
+ tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2
+
+ w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+
+ if log:
+ return w, {"optimal_theta": tc[:, 0]}
+ return w
+
+
+def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True):
+ r"""Computes the 1-Wasserstein distance on the circle using the level median [45].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates
+ using e.g. the atan2 function.
+ The function runs on backend but tensorflow is not supported.
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein1_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ """
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0)
+
+ cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0)
+ cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0)
+
+ values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1)
+ delta = values_sorted[1:, ...] - values_sorted[:-1, ...]
+ weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0)
+
+ sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5
+ sum_weights[sum_weights < 0] = np.inf
+ inds = nx.argmin(sum_weights, axis=0)
+
+ levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0)
+
+ return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0)
+
+
+def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True):
+ r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or
+ the binary search algorithm proposed in [44] otherwise.
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates
+ using e.g. the atan2 function.
+
+ General loss returned:
+
+ .. math::
+ OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ For p=1, [45]
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC. For p>1.
+ Lp : int, optional
+ Upper bound dC. For p>1.
+ tm: float, optional
+ Lower bound theta. For p>1.
+ tp: float, optional
+ Upper bound theta. For p>1.
+ eps: float, optional
+ Stopping condition. For p>1.
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if p == 1:
+ return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort)
+
+ return binary_search_circle(u_values, v_values, u_weights, v_weights,
+ p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps,
+ require_sort=require_sort)
+
+
+def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
+ r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1`
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12}
+
+ where:
+
+ - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ Parameters
+ ----------
+ u_values: ndarray, shape (n, ...)
+ Samples
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> x0 = np.array([[0], [0.2], [0.4]])
+ >>> semidiscrete_wasserstein2_unif_circle(x0)
+ array([0.02111111])
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+
+ if u_weights is not None:
+ nx = get_backend(u_values, u_weights)
+ else:
+ nx = get_backend(u_values)
+
+ n = u_values.shape[0]
+
+ u_values = u_values % 1
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+
+ u_values = nx.sort(u_values, 0)
+ u_cdf = nx.cumsum(u_weights, 0)
+ u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)])
+
+ cpt1 = nx.sum(u_weights * u_values**2, axis=0)
+ u_mean = nx.sum(u_weights * u_values, axis=0)
+
+ ns = 1 - u_weights - 2 * u_cdf[:-1]
+ cpt2 = nx.sum(u_values * u_weights * ns, axis=0)
+
+ return cpt1 - u_mean**2 + cpt2 + 1 / 12