diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-12-15 09:28:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-15 09:28:01 +0100 |
commit | 0411ea22a96f9c22af30156b45c16ef39ffb520d (patch) | |
tree | 7c131ad804d5b16a8c362c2fe296350a770400df /ot/smooth.py | |
parent | 8490196dcc982c492b7565e1ec4de5f75f006acf (diff) |
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver
* use itertools for product of parameters
* add tests for result class
* add tests for result class
* add tests for result class last time?
* add sinkhorn
* make partial OT bckend compatible
* add TV as unbalanced flavor
* better tests
* make smoth backend compatible and add l2 tregularizatio to solve
* add reularizedd unbalanced
* add test for more complex attibutes
* add test for more complex attibutes
* add generic unbalaned solver and implement it for ot.solve
* add entropy to possible regularization
* star of documentation for ot.solv
* weird new pep8
* documenttaion for function ot.solve done
* pep8
* Update ot/solvers.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* update release file
* Apply suggestions from code review
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* add test NotImplemented
* pep8
* pep8gcmp pep8!
* compute kl in backend
* debug tensorflow kl backend
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/smooth.py')
-rw-r--r-- | ot/smooth.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/ot/smooth.py b/ot/smooth.py index 6855005..8e0ef38 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -44,6 +44,7 @@ Original code from https://github.com/mblondel/smooth-ot/ import numpy as np from scipy.optimize import minimize +from .backend import get_backend def projection_simplex(V, z=1, axis=None): @@ -511,6 +512,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, """ + nx = get_backend(a, b, M) + if reg_type.lower() in ['l2', 'squaredl2']: regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: @@ -518,15 +521,19 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, else: raise NotImplementedError('Unknown regularization') + a0, b0, M0 = a, b, M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) + # solve dual alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr, verbose=verbose) # reconstruct transport matrix - G = get_plan_from_dual(alpha, beta, M, regul) + G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), type_as=M0) if log: - log = {'alpha': alpha, 'beta': beta, 'res': res} + log = {'alpha': nx.from_numpy(alpha, type_as=a0), 'beta': nx.from_numpy(beta, type_as=b0), 'res': res} return G, log else: return G |