summaryrefslogtreecommitdiff
path: root/ot/gpu/__init__.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2020-01-20 14:07:53 +0100
committerGard Spreemann <gspr@nonempty.org>2020-01-20 14:07:53 +0100
commitbdfb24ff37ea777d6e266b145047cd4e281ebac3 (patch)
tree00cbac5f3dc25a4ee76164828abd72c1cbab37cc /ot/gpu/__init__.py
parentabc441b00f0fe2fa4ef0efc4e1aa67b27cca9a13 (diff)
parent5e70a77fbb2feec513f21c9ef65dcc535329ace6 (diff)
Merge tag '0.6.0' into debian/sid
Diffstat (limited to 'ot/gpu/__init__.py')
-rw-r--r--ot/gpu/__init__.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
new file mode 100644
index 0000000..1ab95bb
--- /dev/null
+++ b/ot/gpu/__init__.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+"""
+
+This module provides GPU implementation for several OT solvers and utility
+functions. The GPU backend in handled by `cupy
+<https://cupy.chainer.org/>`_.
+
+.. warning::
+ Note that by default the module is not import in :mod:`ot`. In order to
+ use it you need to explicitely import :mod:`ot.gpu` .
+
+By default, the functions in this module accept and return numpy arrays
+in order to proide drop-in replacement for the other POT function but
+the transfer between CPU en GPU comes with a significant overhead.
+
+In order to get the best performances, we recommend to give only cupy
+arrays to the functions and desactivate the conversion to numpy of the
+result of the function with parameter ``to_numpy=False``.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
+from . import bregman
+from . import da
+from .bregman import sinkhorn
+from .da import sinkhorn_lpl1_mm
+
+from . import utils
+from .utils import dist, to_gpu, to_np
+
+
+
+
+
+__all__ = ["utils", "dist", "sinkhorn",
+ "sinkhorn_lpl1_mm", 'bregman', 'da', 'to_gpu', 'to_np']
+