summaryrefslogtreecommitdiff
path: root/ot/gpu/cudamat/cudamat/learn.py
diff options
context:
space:
mode:
authorLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 12:12:15 +0200
committerLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 12:12:15 +0200
commit16f51f971607efab2c73958d207c582b389406c8 (patch)
tree299a4f6f13faf8545d2144767e9a7791098aacf8 /ot/gpu/cudamat/cudamat/learn.py
parent48ec27d8e1c2599bd6d9015d15f4204b8116af28 (diff)
sinkhorn GPU implementation
Diffstat (limited to 'ot/gpu/cudamat/cudamat/learn.py')
-rw-r--r--ot/gpu/cudamat/cudamat/learn.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/ot/gpu/cudamat/cudamat/learn.py b/ot/gpu/cudamat/cudamat/learn.py
new file mode 100644
index 0000000..741ca13
--- /dev/null
+++ b/ot/gpu/cudamat/cudamat/learn.py
@@ -0,0 +1,21 @@
+import os
+
+import ctypes as ct
+import numpy as np
+
+from cudamat import load_library, generate_exception
+
+_cudalearn = load_library('libcudalearn')
+
+_cudalearn.mult_by_sigmoid_deriv.restype = ct.c_int
+
+def mult_by_sigmoid_deriv(target, acts):
+ """
+ target = target * acts * (1 - acts)
+
+ Useful for doing backprop in neural networks with logistic units.
+ """
+
+ err_code = _cudalearn.mult_by_sigmoid_deriv(target.p_mat, acts.p_mat)
+ if err_code:
+ raise generate_exception(err_code)