summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/EMD.h2
-rw-r--r--ot/lp/EMD_wrapper.cpp9
-rw-r--r--ot/lp/__init__.py247
-rw-r--r--ot/lp/emd_wrap.pyx27
-rw-r--r--ot/lp/network_simplex_simple.h7
5 files changed, 250 insertions, 42 deletions
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index f42e222..c0fe7a3 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -32,4 +32,6 @@ enum ProblemType {
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
+
+
#endif
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index fc7ca63..bc873ed 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -17,13 +17,13 @@
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int maxIter) {
-// beware M and C anre strored in row major C style!!!
- int n, m, i, cur;
+ // beware M and C anre strored in row major C style!!!
+ int n, m, i, cur;
typedef FullBipartiteDigraph Digraph;
- DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
+ DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
- // Get the number of non zero coordinates for r and c
+ // Get the number of non zero coordinates for r and c
n=0;
for (int i=0; i<n1; i++) {
double val=*(X+i);
@@ -105,3 +105,4 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
return ret;
}
+
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 0c92810..514a607 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -2,8 +2,6 @@
"""
Solvers for the original linear program OT problem
-
-
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -12,22 +10,169 @@ Solvers for the original linear program OT problem
import multiprocessing
import sys
+
import numpy as np
from scipy.sparse import coo_matrix
-from .import cvx
-
+from . import cvx
+from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import parmap
-from .cvx import barycenter
from ..utils import dist
+from ..utils import parmap
+
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d']
+
+
+def center_ot_dual(alpha0, beta0, a=None, b=None):
+ r"""Center dual OT potentials w.r.t. theirs weights
+
+ The main idea of this function is to find unique dual potentials
+ that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having
+ stability when multiple calling of the OT solver with small changes.
+
+ Basically we add another constraint to the potential that will not
+ change the objective value but will ensure unicity. The constraint
+ is the following:
+
+ .. math::
+ \alpha^T a= \beta^T b
+
+ in addition to the OT problem constraints.
+
+ since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing
+ a constant from both :math:`\alpha_0` and :math:`\beta_0`.
+
+ .. math::
+ c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta}
+
+ \alpha=\alpha_0+c
+
+ \beta=\beta0+c
+
+ Parameters
+ ----------
+ alpha0 : (ns,) numpy.ndarray, float64
+ Source dual potential
+ beta0 : (nt,) numpy.ndarray, float64
+ Target dual potential
+ a : (ns,) numpy.ndarray, float64
+ Source histogram (uniform weight if empty list)
+ b : (nt,) numpy.ndarray, float64
+ Target histogram (uniform weight if empty list)
+
+ Returns
+ -------
+ alpha : (ns,) numpy.ndarray, float64
+ Source centered dual potential
+ beta : (nt,) numpy.ndarray, float64
+ Target centered dual potential
+
+ """
+ # if no weights are provided, use uniform
+ if a is None:
+ a = np.ones(alpha0.shape[0]) / alpha0.shape[0]
+ if b is None:
+ b = np.ones(beta0.shape[0]) / beta0.shape[0]
+
+ # compute constant that balances the weighted sums of the duals
+ c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())
+
+ # update duals
+ alpha = alpha0 + c
+ beta = beta0 - c
+
+ return alpha, beta
+
+
+def estimate_dual_null_weights(alpha0, beta0, a, b, M):
+ r"""Estimate feasible values for 0-weighted dual potentials
+
+ The feasible values are computed efficiently but rather coarsely.
+
+ .. warning::
+ This function is necessary because the C++ solver in emd_c
+ discards all samples in the distributions with
+ zeros weights. This means that while the primal variable (transport
+ matrix) is exact, the solver only returns feasible dual potentials
+ on the samples with weights different from zero.
+
+ First we compute the constraints violations:
+
+ .. math::
+ V=\alpha+\beta^T-M
+
+ Next we compute the max amount of violation per row (alpha) and
+ columns (beta)
-__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d']
+ .. math::
+ v^a_i=\max_j V_{i,j}
+
+ v^b_j=\max_i V_{i,j}
+
+ Finally we update the dual potential with 0 weights if a
+ constraint is violated
+
+ .. math::
+ \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0
+
+ \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0
+
+ In the end the dual potentials are centered using function
+ :ref:`center_ot_dual`.
+
+ Note that all those updates do not change the objective value of the
+ solution but provide dual potentials that do not violate the constraints.
+
+ Parameters
+ ----------
+ alpha0 : (ns,) numpy.ndarray, float64
+ Source dual potential
+ beta0 : (nt,) numpy.ndarray, float64
+ Target dual potential
+ alpha0 : (ns,) numpy.ndarray, float64
+ Source dual potential
+ beta0 : (nt,) numpy.ndarray, float64
+ Target dual potential
+ a : (ns,) numpy.ndarray, float64
+ Source distribution (uniform weights if empty list)
+ b : (nt,) numpy.ndarray, float64
+ Target distribution (uniform weights if empty list)
+ M : (ns,nt) numpy.ndarray, float64
+ Loss matrix (c-order array with type float64)
+
+ Returns
+ -------
+ alpha : (ns,) numpy.ndarray, float64
+ Source corrected dual potential
+ beta : (nt,) numpy.ndarray, float64
+ Target corrected dual potential
+
+ """
+
+ # binary indexing of non-zeros weights
+ asel = a != 0
+ bsel = b != 0
+
+ # compute dual constraints violation
+ constraint_violation = alpha0[:, None] + beta0[None, :] - M
+
+ # Compute largest violation per line and columns
+ aviol = np.max(constraint_violation, 1)
+ bviol = np.max(constraint_violation, 0)
+ # update corrects violation of
+ alpha_up = -1 * ~asel * np.maximum(aviol, 0)
+ beta_up = -1 * ~bsel * np.maximum(bviol, 0)
-def emd(a, b, M, numItermax=100000, log=False):
+ alpha = alpha0 + alpha_up
+ beta = beta0 + beta_up
+
+ return center_ot_dual(alpha, beta, a, b)
+
+
+def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -35,7 +180,9 @@ def emd(a, b, M, numItermax=100000, log=False):
\gamma = arg\min_\gamma <\gamma,M>_F
s.t. \gamma 1 = a
+
\gamma^T 1= b
+
\gamma\geq 0
where :
@@ -43,7 +190,7 @@ def emd(a, b, M, numItermax=100000, log=False):
- a and b are the sample weights
.. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
+ Note that the M matrix needs to be a C-order numpy.array in float64
format.
Uses the algorithm proposed in [1]_
@@ -62,6 +209,9 @@ def emd(a, b, M, numItermax=100000, log=False):
log: bool, optional (default=False)
If True, returns a dictionary containing the cost and dual
variables. Otherwise returns only the optimal transportation matrix.
+ center_dual: boolean, optional (default=True)
+ If True, centers the dual potential using function
+ :ref:`center_ot_dual`.
Returns
-------
@@ -109,7 +259,20 @@ def emd(a, b, M, numItermax=100000, log=False):
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
+ "Dimension mismatch, check dimensions of M with a and b"
+
+ asel = a != 0
+ bsel = b != 0
+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+
+ if center_dual:
+ u, v = center_ot_dual(u, v, a, b)
+
+ if np.any(~asel) or np.any(~bsel):
+ u, v = estimate_dual_null_weights(u, v, a, b, M)
+
result_code_string = check_result(result_code)
if log:
log = {}
@@ -123,14 +286,17 @@ def emd(a, b, M, numItermax=100000, log=False):
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- numItermax=100000, log=False, return_matrix=False):
+ numItermax=100000, log=False, return_matrix=False,
+ center_dual=True):
r"""Solves the Earth Movers distance problem and returns the loss
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ \min_\gamma <\gamma,M>_F
s.t. \gamma 1 = a
+
\gamma^T 1= b
+
\gamma\geq 0
where :
@@ -138,7 +304,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- a and b are the sample weights
.. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
+ Note that the M matrix needs to be a C-order numpy.array in float64
format.
Uses the algorithm proposed in [1]_
@@ -161,6 +327,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
variables. Otherwise returns only the optimal transportation cost.
return_matrix: boolean, optional (default=False)
If True, returns the optimal transportation matrix in the log.
+ center_dual: boolean, optional (default=True)
+ If True, centers the dual potential using function
+ :ref:`center_ot_dual`.
Returns
-------
@@ -204,7 +373,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
# problem with pikling Forks
if sys.platform.endswith('win32'):
- processes=1
+ processes = 1
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -212,21 +381,43 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
+ "Dimension mismatch, check dimensions of M with a and b"
+
+ asel = a != 0
+
if log or return_matrix:
def f(b):
- G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
- result_code_string = check_result(resultCode)
+ bsel = b != 0
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+
+ if center_dual:
+ u, v = center_ot_dual(u, v, a, b)
+
+ if np.any(~asel) or np.any(~bsel):
+ u, v = estimate_dual_null_weights(u, v, a, b, M)
+
+ result_code_string = check_result(result_code)
log = {}
if return_matrix:
log['G'] = G
log['u'] = u
log['v'] = v
log['warning'] = result_code_string
- log['result_code'] = resultCode
+ log['result_code'] = result_code
return [cost, log]
else:
def f(b):
+ bsel = b != 0
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+
+ if center_dual:
+ u, v = center_ot_dual(u, v, a, b)
+
+ if np.any(~asel) or np.any(~bsel):
+ u, v = estimate_dual_null_weights(u, v, a, b, M)
+
check_result(result_code)
return cost
@@ -234,7 +425,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return f(b)
nb = b.shape[1]
- if processes>1:
+ if processes > 1:
res = parmap(f, [b[:, i] for i in range(nb)], processes)
else:
res = list(map(f, [b[:, i].copy() for i in range(nb)]))
@@ -242,8 +433,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return res
-
-def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
+def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
+ stopThr=1e-7, verbose=False, log=None):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -295,7 +486,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
- b = np.ones((k,))/k
+ b = np.ones((k,)) / k
if weights is None:
weights = np.ones((N,)) / N
@@ -306,17 +497,17 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
displacement_square_norm = stopThr + 1.
- while ( displacement_square_norm > stopThr and iter_count < numItermax ):
+ while (displacement_square_norm > stopThr and iter_count < numItermax):
T_sum = np.zeros((k, d))
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
-
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
+ weights.tolist()):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
- displacement_square_norm = np.sum(np.square(T_sum-X))
+ displacement_square_norm = np.sum(np.square(T_sum - X))
if log:
displacement_square_norms.append(displacement_square_norm)
@@ -436,12 +627,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
- x_a_1d = x_a.reshape((-1, ))
- x_b_1d = x_b.reshape((-1, ))
+ x_a_1d = x_a.reshape((-1,))
+ x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)
perm_b = np.argsort(x_b_1d)
- G_sorted, indices, cost = emd_1d_sorted(a, b,
+ G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
x_a_1d[perm_a], x_b_1d[perm_b],
metric=metric, p=p)
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 2b6c495..c167964 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -19,7 +19,7 @@ import warnings
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter)
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -35,8 +35,7 @@ def check_result(result_code):
message = "numItermax reached before optimality. Try to increase numItermax."
warnings.warn(message)
return message
-
-
+
@cython.boundscheck(False)
@cython.wraparound(False)
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
@@ -61,6 +60,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
.. warning::
Note that the M matrix needs to be a C-order :py.cls:`numpy.array`
+ .. warning::
+ The C++ solver discards all samples in the distributions with
+ zeros weights. This means that while the primal variable (transport
+ matrix) is exact, the solver only returns feasible dual potentials
+ on the samples with weights different from zero.
+
Parameters
----------
a : (ns,) numpy.ndarray, float64
@@ -73,7 +78,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
-
Returns
-------
gamma: (ns x nt) numpy.ndarray
@@ -82,12 +86,19 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
"""
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
+ cdef int nmax=n1+n2-1
+ cdef int result_code = 0
+ cdef int nG=0
cdef double cost=0
- cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2)
+ cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0])
+
+ cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0)
+ cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int)
+ cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int)
if not len(a):
a=np.ones((n1,))/n1
@@ -95,8 +106,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
if not len(b):
b=np.ones((n2,))/n2
+ # init OT matrix
+ G=np.zeros([n1, n2])
+
# calling the function
- cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ with nogil:
+ result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
return G, cost, alpha, beta, result_code
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 7c6a4ce..5d93040 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -686,7 +686,7 @@ namespace lemon {
/// \see resetParams(), reset()
ProblemType run() {
#if DEBUG_LVL>0
- std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED\n";
+ std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ;
#endif
if (!init()) return INFEASIBLE;
@@ -875,7 +875,7 @@ namespace lemon {
c += Number(it->second) * Number(_cost[it->first]);
return c;*/
- for (int i=0; i<_flow.size(); i++)
+ for (unsigned long i=0; i<_flow.size(); i++)
c += _flow[i] * Number(_cost[i]);
return c;
@@ -1257,7 +1257,7 @@ namespace lemon {
u = w;
}
_pred[u_in] = in_arc;
- _forward[u_in] = (u_in == _source[in_arc]);
+ _forward[u_in] = ((unsigned int)u_in == _source[in_arc]);
_succ_num[u_in] = old_succ_num;
// Set limits for updating _last_succ form v_in and v_out
@@ -1418,7 +1418,6 @@ namespace lemon {
template <typename PivotRuleImpl>
ProblemType start() {
PivotRuleImpl pivot(*this);
- double prevCost=-1;
ProblemType retVal = OPTIMAL;
// Perform heuristic initial pivots