diff options
author | kguerda-idris <84066930+kguerda-idris@users.noreply.github.com> | 2021-09-29 15:29:31 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-29 15:29:31 +0200 |
commit | 1c7e7ce2da8bb362c184fb6eae71fe7e36356494 (patch) | |
tree | 92fdc31870b6d5384c8ba83ff72d85a2d5a1eee6 /ot/lp/__init__.py | |
parent | 7dde9e8e4b6aae756e103d49198caaa4f24150e3 (diff) |
[MRG] OpenMP support (#260)
* Added : OpenMP support
Restored : Epsilon and Debug mode
Replaced : parmap => multiprocessing is now replaced by multithreading
* Commit clean up
* Number of CPUs correctly calculated on SLURM clusters
* Corrected number of processes for cluster slurm
* Mistake corrected
* parmap is now deprecated
* Now a different solver is used depending on the requested number of threads
* Tiny mistake corrected
* Folders are now in the ot library instead of at the root
* Helpers is now correctly placed
* Attempt to make compilation work smoothly
* OS compatible path
* NumThreads now defaults to 1
* Better flags
* Mistake corrected in case of OpenMP unavailability
* Revert OpenMP flags modification, which do not compile on Windows
* Test helper functions
* Helpers comments
* Documentation update
* File title corrected
* Warning no longer using print
* Last attempt for macos compilation
* pls work
* attempt
* solving a type error
* TypeError OpenMP
* Compilation finally working on Windows
* Bug solved, number of threads now correctly selected
* 64 bits solver to avoid overflows for bigger problems
* 64 bits EMD corrected
Co-authored-by: kguerda-idris <ssos023@jean-zay3.idris.fr>
Co-authored-by: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>
Co-authored-by: ncassereau <nathan.cassereau@idris.fr>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 71 |
1 files changed, 52 insertions, 19 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c8c9da6..b907b10 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -8,11 +8,13 @@ Solvers for the original linear program OT problem # # License: MIT License +import os import multiprocessing import sys import numpy as np from scipy.sparse import coo_matrix +import warnings from . import cvx from .cvx import barycenter @@ -25,9 +27,28 @@ from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] +def check_number_threads(numThreads): + """Checks whether or not the requested number of threads has a valid value. + + Parameters + ---------- + numThreads : int or str + The requested number of threads, should either be a strictly positive integer or "max" or None + + Returns + ------- + numThreads : int + Corrected number of threads + """ + if (numThreads is None) or (isinstance(numThreads, str) and numThreads.lower() == 'max'): + return -1 + if (not isinstance(numThreads, int)) or numThreads < 1: + raise ValueError('numThreads should either be "max" or a strictly positive integer') + return numThreads + def center_ot_dual(alpha0, beta0, a=None, b=None): - r"""Center dual OT potentials w.r.t. theirs weights + r"""Center dual OT potentials w.r.t. their weights The main idea of this function is to find unique dual potentials that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). 
It will help having @@ -173,7 +194,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): return center_ot_dual(alpha, beta, a, b) -def emd(a, b, M, numItermax=100000, log=False, center_dual=True): +def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -215,6 +236,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): center_dual: boolean, optional (default=True) If True, centers the dual potential using function :ref:`center_ot_dual`. + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. Returns ------- @@ -285,7 +309,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): asel = a != 0 bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + numThreads = check_number_threads(numThreads) + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) if center_dual: u, v = center_ot_dual(u, v, a, b) @@ -305,9 +331,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): return nx.from_numpy(G, type_as=M0) -def emd2(a, b, M, processes=multiprocessing.cpu_count(), +def emd2(a, b, M, processes=1, numItermax=100000, log=False, return_matrix=False, - center_dual=True): + center_dual=True, numThreads=1): r"""Solves the Earth Movers distance problem and returns the loss .. 
math:: @@ -336,8 +362,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), Target histogram (uniform weight if empty list) M : (ns,nt) array-like, float64 Loss matrix (for numpy c-order array with type float64) - processes : int, optional (default=nb cpu) - Nb of processes used for multiple emd computation (not used on windows) + processes : int, optional (default=1) + Nb of processes used for multiple emd computation (deprecated) numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -349,6 +375,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), center_dual: boolean, optional (default=True) If True, centers the dual potential using function :ref:`center_ot_dual`. + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. Returns ------- @@ -390,7 +419,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), a0, b0, M0 = a, b, M nx = get_backend(M0, a0, b0) - + # convert to numpy M = nx.to_numpy(M) a = nx.to_numpy(a) @@ -400,10 +429,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64, order= 'C') - # problem with pikling Forks - if sys.platform.endswith('win32') or not nx.__name__ == 'numpy': - processes = 1 - # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] @@ -415,11 +440,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), asel = a != 0 + numThreads = check_number_threads(numThreads) + if log or return_matrix: def f(b): bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) if center_dual: u, v = center_ot_dual(u, v, a, b) @@ -442,7 +469,7 @@ def emd2(a, b, M, 
processes=multiprocessing.cpu_count(), else: def f(b): bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) if center_dual: u, v = center_ot_dual(u, v, a, b) @@ -463,15 +490,17 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), nb = b.shape[1] if processes > 1: - res = parmap(f, [b[:, i].copy() for i in range(nb)], processes) - else: - res = list(map(f, [b[:, i].copy() for i in range(nb)])) + warnings.warn( + "The 'processes' parameter has been deprecated. " + "Multiprocessing should be done outside of POT." + ) + res = list(map(f, [b[:, i].copy() for i in range(nb)])) return res def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, - stopThr=1e-7, verbose=False, log=None): + stopThr=1e-7, verbose=False, log=None, numThreads=1): r""" Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: @@ -512,6 +541,10 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Print information along iterations log : bool, optional record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + Returns ------- @@ -551,7 +584,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): M_i = dist(X, measure_locations_i) - T_i = emd(b, measure_weights_i, M_i) + T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) displacement_square_norm = np.sum(np.square(T_sum - X)) |