summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-04-21 17:12:29 +0200
committerGitHub <noreply@github.com>2021-04-21 17:12:29 +0200
commit0d995011b19b243bc980588cd98786b7c41a0509 (patch)
tree6d45d2025c63f2633380c4b0339aa0d3a2f2a6f0 /ot/lp
parentcd3ce6140d7a2dbe2bcf05927a8dd8289f4ce9e2 (diff)
[MRG] Fixes issue #239 (deprecated numpy types) (#244)
* remove warning numpy int? * use long long * stoupid mistake * cleanup double test run in PR from local branch
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/emd_wrap.pyx6
1 files changed, 2 insertions, 4 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index d79d0ca..de9a700 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -97,8 +97,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0])
cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0)
- cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int)
- cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int)
if not len(a):
a=np.ones((n1,))/n1
@@ -169,8 +167,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
dtype=np.float64)
- cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
- dtype=np.int)
+ cdef np.ndarray[long long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+ dtype=np.int64)
cdef Py_ssize_t cur_idx = 0
while True:
if metric == 'sqeuclidean':