From f089a3cbc27c30ba9416ea1659c2fdbac1857146 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 16 Feb 2018 13:53:34 +0100 Subject: better pep8 but not solved --- ot/gpu/__init__.py | 2 +- ot/gpu/da.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) (limited to 'ot/gpu') diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py index c8f9433..a2fdd3d 100644 --- a/ot/gpu/__init__.py +++ b/ot/gpu/__init__.py @@ -5,7 +5,7 @@ from . import da from .bregman import sinkhorn # Author: Remi Flamary -# Leo Gautheron +# Leo Gautheron # # License: MIT License diff --git a/ot/gpu/da.py b/ot/gpu/da.py index 05c580f..71a485a 100644 --- a/ot/gpu/da.py +++ b/ot/gpu/da.py @@ -188,6 +188,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10, class OTDA_GPU(OTDA): + def normalizeM(self, norm): if norm == "median": self.M_GPU.divide(float(np.median(self.M_GPU.asarray()))) @@ -204,6 +205,7 @@ class OTDA_GPU(OTDA): class OTDA_sinkhorn(OTDA_GPU): + def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs): cudamat.init() xs = np.asarray(xs, dtype=np.float64) @@ -228,6 +230,7 @@ class OTDA_sinkhorn(OTDA_GPU): class OTDA_lpl1(OTDA_GPU): + def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs): cudamat.init() -- cgit v1.2.3