summary | refs | log | tree | commit | diff
path: root/ot/gpu/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gpu/bregman.py')
-rw-r--r-- ot/gpu/bregman.py 18
1 files changed, 12 insertions, 6 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 2c3e317..47939c4 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -3,13 +3,18 @@
Bregman projections for regularized OT with GPU
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
import numpy as np
import cudamat
def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
log=False, returnAsGPU=False):
- """
+ r"""
Solve the entropic regularization optimal transport problem on GPU
The function solves the following optimization problem:
@@ -82,7 +87,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- """
+ """
# init data
Nini = len(a)
Nfin = len(b)
@@ -92,11 +97,11 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
# we assume that no distances are null except those of the diagonal of
# distances
- u = (np.ones(Nini)/Nini).reshape((Nini, 1))
+ u = (np.ones(Nini) / Nini).reshape((Nini, 1))
u_GPU = cudamat.CUDAMatrix(u)
a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1)))
ones_GPU = cudamat.empty(u_GPU.shape).assign(1)
- v = (np.ones(Nfin)/Nfin).reshape((Nfin, 1))
+ v = (np.ones(Nfin) / Nfin).reshape((Nfin, 1))
v_GPU = cudamat.CUDAMatrix(v)
b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1)))
@@ -121,7 +126,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU)
if (np.any(KtransposeU_GPU.asarray() == 0) or
- not u_GPU.allfinite() or not v_GPU.allfinite()):
+ not u_GPU.allfinite() or not v_GPU.allfinite()):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
@@ -142,7 +147,8 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
if verbose:
if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err')+'\n'+'-'*19)
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
cpt += 1
if log: