summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
author Rémi Flamary <remi.flamary@gmail.com> 2021-06-01 10:10:54 +0200
committer GitHub <noreply@github.com> 2021-06-01 10:10:54 +0200
commit 184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree 483a7274c91030fd644de49b03a5fad04af9deba /ot/lp/__init__.py
parent 1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: 
Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty <ncourty@irisa.fr> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- ot/lp/__init__.py | 137
1 files changed, 89 insertions, 48 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index d5c3a5e..c8c9da6 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -18,8 +18,9 @@ from . import cvx
from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import dist
+from ..utils import dist, list_to_array
from ..utils import parmap
+from ..backend import get_backend
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -176,8 +177,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ .. math:: \gamma = arg\min_\gamma <\gamma,M>_F
s.t. \gamma 1 = a
@@ -189,37 +189,41 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
- M is the metric cost matrix
- a and b are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. warning:: Note that the M matrix in numpy needs to be a C-order
+ numpy.array in float64 format. It will be converted if not in this
+ format
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
- Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
- numItermax : int, optional (default=100000)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list)
+ M : (ns,nt) array-like, float
+ Loss matrix (c-order array in numpy with type float64)
+ numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
- log: bool, optional (default=False)
- If True, returns a dictionary containing the cost and dual
- variables. Otherwise returns only the optimal transportation matrix.
+ algorithm if it has not converged.
+ log: bool, optional (default=False)
+ If True, returns a dictionary containing the cost and dual variables.
+ Otherwise returns only the optimal transportation matrix.
center_dual: boolean, optional (default=True)
- If True, centers the dual potential using function
+ If True, centers the dual potential using function
:ref:`center_ot_dual`.
Returns
-------
- gamma: (ns x nt) numpy.ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is true, a dictionary containing the cost and dual
- variables and exit status
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
Examples
@@ -232,26 +236,37 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.emd(a,b,M)
+ >>> ot.emd(a, b, M)
array([[0.5, 0. ],
[0. , 0.5]])
References
----------
- .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
- (2011, December). Displacement interpolation using Lagrangian mass
- transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
- 158). ACM.
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
+ December). Displacement interpolation using Lagrangian mass transport.
+ In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
See Also
--------
- ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT"""
-
+ ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General
+ regularized OT"""
+
+ # convert to numpy if list
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
+ # ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -262,6 +277,11 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0), err_msg='a and b vector must have the same sum')
+ b=b*a.sum()/b.sum()
+
asel = a != 0
bsel = b != 0
@@ -277,12 +297,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
if log:
log = {}
log['cost'] = cost
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
- return G, log
- return G
+ return nx.from_numpy(G, type_as=M0), log
+ return nx.from_numpy(G, type_as=M0)
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
@@ -303,20 +323,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- M is the metric cost matrix
- a and b are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float64
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
+ b : (nt,) array-like, float64
Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
+ M : (ns,nt) array-like, float64
+ Loss matrix (for numpy c-order array with type float64)
processes : int, optional (default=nb cpu)
Nb of processes used for multiple emd computation (not used on windows)
numItermax : int, optional (default=100000)
@@ -333,9 +352,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
Returns
-------
- W: float
+ W: float, array-like
Optimal transportation loss for the given parameters
- log: dictnp
+ log: dict
If input log is true, a dictionary containing dual
variables and exit status
@@ -367,12 +386,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
ot.bregman.sinkhorn : Entropic regularized OT
ot.optim.cg : General regularized OT"""
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order= 'C')
# problem with pikling Forks
- if sys.platform.endswith('win32'):
+ if sys.platform.endswith('win32') or not nx.__name__ == 'numpy':
processes = 1
# if empty array given then use uniform distributions
@@ -400,12 +429,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
result_code_string = check_result(result_code)
log = {}
+ G = nx.from_numpy(G, type_as=M0)
if return_matrix:
log['G'] = G
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0,b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
def f(b):
@@ -418,6 +450,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
+ G = nx.from_numpy(G, type_as=M0)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0,b0, M0), (nx.from_numpy(u, type_as=a0),
+ nx.from_numpy(v, type_as=b0),G))
+
check_result(result_code)
return cost
@@ -637,6 +674,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum')
+ b=b*a.sum()/b.sum()
+
x_a_1d = x_a.reshape((-1,))
x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)