diff options
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r-- | ot/gpu/da.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py index 7d04b5b..8b45772 100644 --- a/ot/gpu/da.py +++ b/ot/gpu/da.py @@ -6,7 +6,7 @@ Domain adaptation with optimal transport and GPU import numpy as np from ..utils import unif from ..da import OTDA -from .bregman import sinkhornGPU +from .bregman import sinkhorn def pairwiseEuclideanGPU(a, b, returnAsGPU=False, squared=False, cudamat=None): @@ -46,7 +46,7 @@ def pairwiseEuclideanGPU(a, b, returnAsGPU=False, squared=False, cudamat=None): class OTDA_sinkhorn_GPU(OTDA): - def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None): + def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs): import cudamat cudamat.init() xs = np.asarray(xs, dtype=np.float64) @@ -77,5 +77,5 @@ class OTDA_sinkhorn_GPU(OTDA): M = np.log(1 + np.log(1 + self.M_GPU.asarray())) self.M_GPU = cudamat.CUDAMatrix(M) - self.G = sinkhornGPU(ws, wt, self.M_GPU, reg, cudamat=cudamat) + self.G = sinkhorn(ws, wt, self.M_GPU, reg, cudamat=cudamat) self.computed = True |