summaryrefslogtreecommitdiff
path: root/ot/gpu
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gpu')
-rw-r--r--ot/gpu/__init__.py12
-rw-r--r--ot/gpu/bregman.py12
-rw-r--r--ot/gpu/da.py2
3 files changed, 18 insertions, 8 deletions
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index 7478fb9..12db605 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -7,7 +7,13 @@ The GPU backend in handled by `cupy
<https://cupy.chainer.org/>`_.
.. warning::
- Note that by default the module is not import in :mod:`ot`. In order to
+ This module is now deprecated and will be removed in future releases. POT
+ now privides a backend mechanism that allows for solving prolem on GPU wth
+ the pytorch backend.
+
+
+.. warning::
+ Note that by default the module is not imported in :mod:`ot`. In order to
use it you need to explicitely import :mod:`ot.gpu` .
By default, the functions in this module accept and return numpy arrays
@@ -25,6 +31,8 @@ result of the function with parameter ``to_numpy=False``.
#
# License: MIT License
+import warnings
+
from . import bregman
from . import da
from .bregman import sinkhorn
@@ -34,7 +42,7 @@ from . import utils
from .utils import dist, to_gpu, to_np
-
+warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning)
__all__ = ["utils", "dist", "sinkhorn",
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 2e2df83..76af00e 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -15,7 +15,7 @@ from . import utils
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
verbose=False, log=False, to_numpy=True, **kwargs):
- """
+ r"""
Solve the entropic regularization optimal transport on GPU
If the input matrix are in numpy format, they will be uploaded to the
@@ -54,7 +54,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -148,13 +148,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
# we can speed up the process by checking for the error only all
# the 10th iterations
if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err = np.sqrt(
+ np.sum((u - uprev)**2) / np.sum((u)**2)
+ + np.sum((v - vprev)**2) / np.sum((v)**2)
+ )
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
#tmp2=np.einsum('i,ij,j->j', u, K, v)
- err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
+ err = np.linalg.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 4a98038..7adb830 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
labels_a2 = cp.asnumpy(labels_a)
classes = npp.unique(labels_a2)
for c in classes:
- idxc, = utils.to_gpu(npp.where(labels_a2 == c))
+ idxc = utils.to_gpu(*npp.where(labels_a2 == c))
indices_labels.append(idxc)
W = np.zeros(M.shape)