summaryrefslogtreecommitdiff
path: root/ot/gpu/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gpu/__init__.py')
-rw-r--r--ot/gpu/__init__.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index de4825d..9de2c40 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -1,8 +1,28 @@
# -*- coding: utf-8 -*-
+"""
+
+
+This module implement GPU ilmplementation for several OT solvers and utility
+functions. The GPU backend in handled by `cupy
+<https://cupy.chainer.org/>`_.
+
+By default, the functions in this module accept and return numpy arrays
+in order to proide drop-in replacement for the other POT function but
+the transfer between CPU en GPU comes with a significant overhead.
+
+In order to get the best erformances, we recommend to given only cupy
+arrays to the functions and desactivate the conversion to numpy of the
+result of the function with parameter ``to_numpy=False``.
+
+
+
+
+"""
from . import bregman
from . import da
from .bregman import sinkhorn
+from .da
from . import utils
from .utils import dist, to_gpu, to_np
@@ -13,4 +33,4 @@ from .utils import dist, to_gpu, to_np
#
# License: MIT License
-__all__ = ["utils", "dist", "sinkhorn"]
+__all__ = ["utils", "dist", "sinkhorn", 'bregman', 'da', 'to_gpu', 'to_np']