diff options
Diffstat (limited to 'ot/gpu')
-rw-r--r-- | ot/gpu/__init__.py | 12 | ||||
-rw-r--r-- | ot/gpu/bregman.py | 12 | ||||
-rw-r--r-- | ot/gpu/da.py | 2 |
3 files changed, 18 insertions, 8 deletions
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py index 7478fb9..12db605 100644 --- a/ot/gpu/__init__.py +++ b/ot/gpu/__init__.py @@ -7,7 +7,13 @@ The GPU backend in handled by `cupy <https://cupy.chainer.org/>`_. .. warning:: - Note that by default the module is not import in :mod:`ot`. In order to + This module is now deprecated and will be removed in future releases. POT + now privides a backend mechanism that allows for solving prolem on GPU wth + the pytorch backend. + + +.. warning:: + Note that by default the module is not imported in :mod:`ot`. In order to use it you need to explicitely import :mod:`ot.gpu` . By default, the functions in this module accept and return numpy arrays @@ -25,6 +31,8 @@ result of the function with parameter ``to_numpy=False``. # # License: MIT License +import warnings + from . import bregman from . import da from .bregman import sinkhorn @@ -34,7 +42,7 @@ from . import utils from .utils import dist, to_gpu, to_np - +warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning) __all__ = ["utils", "dist", "sinkhorn", diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 2e2df83..76af00e 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -15,7 +15,7 @@ from . import utils def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, to_numpy=True, **kwargs): - """ + r""" Solve the entropic regularization optimal transport on GPU If the input matrix are in numpy format, they will be uploaded to the @@ -54,7 +54,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -148,13 +148,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, # we can speed up the process by checking for the error only all # the 10th iterations if nbb: - err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ - np.sum((v - vprev)**2) / np.sum((v)**2) + err = np.sqrt( + np.sum((u - uprev)**2) / np.sum((u)**2) + + np.sum((v - vprev)**2) / np.sum((v)**2) + ) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 tmp2 = np.sum(u[:, None] * K * v[None, :], 0) #tmp2=np.einsum('i,ij,j->j', u, K, v) - err = np.linalg.norm(tmp2 - b)**2 # violation of marginal + err = np.linalg.norm(tmp2 - b) # violation of marginal if log: log['err'].append(err) diff --git a/ot/gpu/da.py b/ot/gpu/da.py index 4a98038..7adb830 100644 --- a/ot/gpu/da.py +++ b/ot/gpu/da.py @@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, labels_a2 = cp.asnumpy(labels_a) classes = npp.unique(labels_a2) for c in classes: - idxc, = utils.to_gpu(npp.where(labels_a2 == c)) + idxc = utils.to_gpu(*npp.where(labels_a2 == c)) indices_labels.append(idxc) W = np.zeros(M.shape) |