summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
Diffstat (limited to 'ot')
-rw-r--r--ot/optim.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/optim.py b/ot/optim.py
index 58e5596..b15c77b 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -93,7 +93,7 @@ def line_search_armijo(
fc[0] += 1
alpha10 = nx.from_numpy(alpha1)
fval = f(xk0 + alpha10 * pk0, *args)
- if type(fval) is float:
+ if isinstance(fval, float):
# prevent bug from nx.to_numpy that can look for .cpu or .gpu
return fval
else:
@@ -101,7 +101,7 @@ def line_search_armijo(
if old_fval is None:
phi0 = phi(0.)
- elif type(old_fval) is float:
+ elif isinstance(old_fval, float):
# prevent bug from nx.to_numpy that can look for .cpu or .gpu
phi0 = old_fval
else: