summaryrefslogtreecommitdiff
path: root/ot/gpu
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-02-16 13:53:34 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-02-16 13:53:34 +0100
commitf089a3cbc27c30ba9416ea1659c2fdbac1857146 (patch)
tree04905ed08da51190f15faa7aff67acb308ee583f /ot/gpu
parent4a585de94109102c89bcd7ad43e35772e1027cd2 (diff)
better pep8 but not solved
Diffstat (limited to 'ot/gpu')
-rw-r--r--ot/gpu/__init__.py2
-rw-r--r--ot/gpu/da.py3
2 files changed, 4 insertions, 1 deletions
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index c8f9433..a2fdd3d 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -5,7 +5,7 @@ from . import da
from .bregman import sinkhorn
# Author: Remi Flamary <remi.flamary@unice.fr>
-# Leo Gautheron <https://github.com/aje>
+# Leo Gautheron <https://github.com/aje>
#
# License: MIT License
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 05c580f..71a485a 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -188,6 +188,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10,
class OTDA_GPU(OTDA):
+
def normalizeM(self, norm):
if norm == "median":
self.M_GPU.divide(float(np.median(self.M_GPU.asarray())))
@@ -204,6 +205,7 @@ class OTDA_GPU(OTDA):
class OTDA_sinkhorn(OTDA_GPU):
+
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
cudamat.init()
xs = np.asarray(xs, dtype=np.float64)
@@ -228,6 +230,7 @@ class OTDA_sinkhorn(OTDA_GPU):
class OTDA_lpl1(OTDA_GPU):
+
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
**kwargs):
cudamat.init()