summaryrefslogtreecommitdiff
path: root/ot/gpu/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-02-16 13:53:34 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-02-16 13:53:34 +0100
commitf089a3cbc27c30ba9416ea1659c2fdbac1857146 (patch)
tree04905ed08da51190f15faa7aff67acb308ee583f /ot/gpu/da.py
parent4a585de94109102c89bcd7ad43e35772e1027cd2 (diff)
better pep8 but not solved
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r--ot/gpu/da.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 05c580f..71a485a 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -188,6 +188,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10,
class OTDA_GPU(OTDA):
+
def normalizeM(self, norm):
if norm == "median":
self.M_GPU.divide(float(np.median(self.M_GPU.asarray())))
@@ -204,6 +205,7 @@ class OTDA_GPU(OTDA):
class OTDA_sinkhorn(OTDA_GPU):
+
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
cudamat.init()
xs = np.asarray(xs, dtype=np.float64)
@@ -228,6 +230,7 @@ class OTDA_sinkhorn(OTDA_GPU):
class OTDA_lpl1(OTDA_GPU):
+
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
**kwargs):
cudamat.init()