summaryrefslogtreecommitdiff
path: root/ot/gpu/cudamat/examples/util.py
diff options
context:
space:
mode:
authorLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 12:12:15 +0200
committerLeo gautheron <gautheron@iv-cm-359.creatis.insa-lyon.fr>2017-04-20 12:12:15 +0200
commit16f51f971607efab2c73958d207c582b389406c8 (patch)
tree299a4f6f13faf8545d2144767e9a7791098aacf8 /ot/gpu/cudamat/examples/util.py
parent48ec27d8e1c2599bd6d9015d15f4204b8116af28 (diff)
sinkhorn GPU implementation
Diffstat (limited to 'ot/gpu/cudamat/examples/util.py')
-rw-r--r--ot/gpu/cudamat/examples/util.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/ot/gpu/cudamat/examples/util.py b/ot/gpu/cudamat/examples/util.py
new file mode 100644
index 0000000..79ceead
--- /dev/null
+++ b/ot/gpu/cudamat/examples/util.py
@@ -0,0 +1,22 @@
+from __future__ import division
+import gzip
+try: import cPickle as pickle
+except: import pickle
+
+def save(fname, var_list, source_dict):
+ var_list = [var.strip() for var in var_list.split() if len(var.strip())>0]
+ fo = gzip.GzipFile(fname, 'wb')
+ pickle.dump(var_list, fo)
+ for var in var_list:
+ pickle.dump(source_dict[var], fo, protocol=2)
+ fo.close()
+
+def load(fname, target_dict, verbose = True):
+ fo = gzip.GzipFile(fname, 'rb')
+ var_list = pickle.load(fo)
+ if verbose:
+ print(var_list)
+ for var in var_list:
+ target_dict[var] = pickle.load(fo)
+ fo.close()
+