summary | refs | log | tree | commit | diff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- ot/lp/__init__.py 71
1 files changed, 52 insertions, 19 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index c8c9da6..b907b10 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -8,11 +8,13 @@ Solvers for the original linear program OT problem
#
# License: MIT License
+import os
import multiprocessing
import sys
import numpy as np
from scipy.sparse import coo_matrix
+import warnings
from . import cvx
from .cvx import barycenter
@@ -25,9 +27,28 @@ from ..backend import get_backend
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
def check_number_threads(numThreads):
    """Checks whether or not the requested number of threads has a valid value.

    Parameters
    ----------
    numThreads : int, "max" or None
        The requested number of threads. Should either be a strictly
        positive integer (a Python ``int`` or a NumPy integer scalar),
        the string "max" (compared case-insensitively), or None.

    Returns
    -------
    numThreads : int
        Corrected number of threads: -1 when "max" or None was requested,
        otherwise the requested value converted to a Python ``int``.

    Raises
    ------
    ValueError
        If ``numThreads`` is neither "max", None, nor a strictly positive
        integer.
    """
    if numThreads is None:
        return -1
    if isinstance(numThreads, str) and numThreads.lower() == 'max':
        return -1
    # NumPy integer scalars (e.g. np.int64) are not subclasses of ``int``,
    # so they have to be accepted explicitly.
    if not isinstance(numThreads, (int, np.integer)) or numThreads < 1:
        raise ValueError('numThreads should either be "max" or a strictly positive integer')
    # Normalize to a plain Python int regardless of the integer type given.
    return int(numThreads)
+
def center_ot_dual(alpha0, beta0, a=None, b=None):
- r"""Center dual OT potentials w.r.t. theirs weights
+ r"""Center dual OT potentials w.r.t. their weights
The main idea of this function is to find unique dual potentials
that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having
@@ -173,7 +194,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
return center_ot_dual(alpha, beta, a, b)
-def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
+def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
r"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -215,6 +236,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
:ref:`center_ot_dual`.
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
Returns
-------
@@ -285,7 +309,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
asel = a != 0
bsel = b != 0
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ numThreads = check_number_threads(numThreads)
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
@@ -305,9 +331,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
return nx.from_numpy(G, type_as=M0)
-def emd2(a, b, M, processes=multiprocessing.cpu_count(),
+def emd2(a, b, M, processes=1,
numItermax=100000, log=False, return_matrix=False,
- center_dual=True):
+ center_dual=True, numThreads=1):
r"""Solves the Earth Movers distance problem and returns the loss
.. math::
@@ -336,8 +362,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
Target histogram (uniform weight if empty list)
M : (ns,nt) array-like, float64
Loss matrix (for numpy c-order array with type float64)
- processes : int, optional (default=nb cpu)
- Nb of processes used for multiple emd computation (not used on windows)
+ processes : int, optional (default=1)
+ Nb of processes used for multiple emd computation (deprecated)
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
@@ -349,6 +375,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
:ref:`center_ot_dual`.
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
Returns
-------
@@ -390,7 +419,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
a0, b0, M0 = a, b, M
nx = get_backend(M0, a0, b0)
-
+
# convert to numpy
M = nx.to_numpy(M)
a = nx.to_numpy(a)
@@ -400,10 +429,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64, order= 'C')
- # problem with pikling Forks
- if sys.platform.endswith('win32') or not nx.__name__ == 'numpy':
- processes = 1
-
# if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -415,11 +440,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
asel = a != 0
+ numThreads = check_number_threads(numThreads)
+
if log or return_matrix:
def f(b):
bsel = b != 0
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
@@ -442,7 +469,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
else:
def f(b):
bsel = b != 0
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
@@ -463,15 +490,17 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
nb = b.shape[1]
if processes > 1:
- res = parmap(f, [b[:, i].copy() for i in range(nb)], processes)
- else:
- res = list(map(f, [b[:, i].copy() for i in range(nb)]))
+ warnings.warn(
+ "The 'processes' parameter has been deprecated. "
+ "Multiprocessing should be done outside of POT."
+ )
+ res = list(map(f, [b[:, i].copy() for i in range(nb)]))
return res
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
- stopThr=1e-7, verbose=False, log=None):
+ stopThr=1e-7, verbose=False, log=None, numThreads=1):
r"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally:
@@ -512,6 +541,10 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Print information along iterations
log : bool, optional
record log if True
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
+
Returns
-------
@@ -551,7 +584,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
weights.tolist()):
M_i = dist(X, measure_locations_i)
- T_i = emd(b, measure_weights_i, M_i)
+ T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
displacement_square_norm = np.sum(np.square(T_sum - X))