summaryrefslogtreecommitdiff
path: root/ot/smooth.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2023-06-14 16:52:13 +0200
committerGard Spreemann <gspr@nonempty.org>2023-06-14 16:52:13 +0200
commit2b51c7bfcf54d7e17ac7c2514f54408543cbe126 (patch)
treebaf00cc603ceabad00626259eec898e6747d016c /ot/smooth.py
parenta49f648b0b07737f7ef315fb83d8f78871780281 (diff)
parent96788a3fe5601e4c3f49b592aa0d9c034247862e (diff)
Merge branch 'dfsg/latest'
Diffstat (limited to 'ot/smooth.py')
-rw-r--r--ot/smooth.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/ot/smooth.py b/ot/smooth.py
index 6855005..8e0ef38 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -44,6 +44,7 @@ Original code from https://github.com/mblondel/smooth-ot/
import numpy as np
from scipy.optimize import minimize
+from .backend import get_backend
def projection_simplex(V, z=1, axis=None):
@@ -511,6 +512,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
"""
+ nx = get_backend(a, b, M)
+
if reg_type.lower() in ['l2', 'squaredl2']:
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
@@ -518,15 +521,19 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
else:
raise NotImplementedError('Unknown regularization')
+ a0, b0, M0 = a, b, M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
+
# solve dual
alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
tol=stopThr, verbose=verbose)
# reconstruct transport matrix
- G = get_plan_from_dual(alpha, beta, M, regul)
+ G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), type_as=M0)
if log:
- log = {'alpha': alpha, 'beta': beta, 'res': res}
+ log = {'alpha': nx.from_numpy(alpha, type_as=a0), 'beta': nx.from_numpy(beta, type_as=b0), 'res': res}
return G, log
else:
return G