summaryrefslogtreecommitdiff
path: root/ot/gpu/__init__.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-09-28 15:32:53 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-09-28 15:32:53 +0200
commitc18f73739fdd8e8ca0ef88dddcf2ba039a85dacf (patch)
tree9f00d6baa9d68457917190cdad807c027fc85b56 /ot/gpu/__init__.py
parentc531fba46d9ea703a8a5ae270efe44466d54a593 (diff)
parent8f6c45534dcdd6e8f80a31205be28481fc79c533 (diff)
merge with new gpu stuff
Diffstat (limited to 'ot/gpu/__init__.py')
-rw-r--r--ot/gpu/__init__.py34
1 files changed, 27 insertions, 7 deletions
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index ed6dcc4..213578c 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -1,17 +1,37 @@
# -*- coding: utf-8 -*-
+"""
-from . import bregman
-from . import da
-from .bregman import sinkhorn
+This module provides GPU implementation for several OT solvers and utility
+functions. The GPU backend in handled by `cupy
+<https://cupy.chainer.org/>`_.
+
+By default, the functions in this module accept and return numpy arrays
+in order to proide drop-in replacement for the other POT function but
+the transfer between CPU en GPU comes with a significant overhead.
+
+In order to get the best erformances, we recommend to given only cupy
+arrays to the functions and desactivate the conversion to numpy of the
+result of the function with parameter ``to_numpy=False``.
+
+"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Leo Gautheron <https://github.com/aje>
#
# License: MIT License
-import warnings
+from . import bregman
+from . import da
+from .bregman import sinkhorn
+from .da import sinkhorn_lpl1_mm
+
+from . import utils
+from .utils import dist, to_gpu, to_np
+
+
+
+
-warnings.warn("the ot.gpu module is deprecated because cudamat in no longer maintained", DeprecationWarning,
- stacklevel=2)
+__all__ = ["utils", "dist", "sinkhorn",
+ "sinkhorn_lpl1_mm", 'bregman', 'da', 'to_gpu', 'to_np']
-__all__ = ["bregman", "da", "sinkhorn"]