summaryrefslogtreecommitdiff
path: root/ot/gpu/da.py
diff options
context:
space:
mode:
authorLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 13:51:30 +0200
committerLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 13:51:30 +0200
commit67f09ad4089594457cb00702dd3aa8a6a94ec2eb (patch)
treed458e5661f8ae9dd1dfe8127003ac4141f74b7d0 /ot/gpu/da.py
parent16f51f971607efab2c73958d207c582b389406c8 (diff)
changes from feedback
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r--ot/gpu/da.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 7d04b5b..8b45772 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -6,7 +6,7 @@ Domain adaptation with optimal transport and GPU
import numpy as np
from ..utils import unif
from ..da import OTDA
-from .bregman import sinkhornGPU
+from .bregman import sinkhorn
def pairwiseEuclideanGPU(a, b, returnAsGPU=False, squared=False, cudamat=None):
@@ -46,7 +46,7 @@ def pairwiseEuclideanGPU(a, b, returnAsGPU=False, squared=False, cudamat=None):
class OTDA_sinkhorn_GPU(OTDA):
- def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None):
+ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
import cudamat
cudamat.init()
xs = np.asarray(xs, dtype=np.float64)
@@ -77,5 +77,5 @@ class OTDA_sinkhorn_GPU(OTDA):
M = np.log(1 + np.log(1 + self.M_GPU.asarray()))
self.M_GPU = cudamat.CUDAMatrix(M)
- self.G = sinkhornGPU(ws, wt, self.M_GPU, reg, cudamat=cudamat)
+ self.G = sinkhorn(ws, wt, self.M_GPU, reg, cudamat=cudamat)
self.computed = True