summaryrefslogtreecommitdiff
path: root/ot/smooth.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-12-15 09:28:01 +0100
committerGitHub <noreply@github.com>2022-12-15 09:28:01 +0100
commit0411ea22a96f9c22af30156b45c16ef39ffb520d (patch)
tree7c131ad804d5b16a8c362c2fe296350a770400df /ot/smooth.py
parent8490196dcc982c492b7565e1ec4de5f75f006acf (diff)
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/smooth.py')
-rw-r--r--ot/smooth.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/ot/smooth.py b/ot/smooth.py
index 6855005..8e0ef38 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -44,6 +44,7 @@ Original code from https://github.com/mblondel/smooth-ot/
import numpy as np
from scipy.optimize import minimize
+from .backend import get_backend
def projection_simplex(V, z=1, axis=None):
@@ -511,6 +512,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
"""
+ nx = get_backend(a, b, M)
+
if reg_type.lower() in ['l2', 'squaredl2']:
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
@@ -518,15 +521,19 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
else:
raise NotImplementedError('Unknown regularization')
+ a0, b0, M0 = a, b, M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
+
# solve dual
alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
tol=stopThr, verbose=verbose)
# reconstruct transport matrix
- G = get_plan_from_dual(alpha, beta, M, regul)
+ G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), type_as=M0)
if log:
- log = {'alpha': alpha, 'beta': beta, 'res': res}
+ log = {'alpha': nx.from_numpy(alpha, type_as=a0), 'beta': nx.from_numpy(beta, type_as=b0), 'res': res}
return G, log
else:
return G