From 362a7f8fa20cf7ae6f2e36d7e47c7ca9f81d3c51 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Thu, 27 Jun 2019 13:29:19 +0200 Subject: Added RT as a contributor + "optimized" Cython math operations --- README.md | 1 + ot/lp/emd_wrap.pyx | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d24d8b9..84148f8 100644 --- a/README.md +++ b/README.md @@ -167,6 +167,7 @@ The contributors to this library are: * [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) * [Vayer Titouan](https://tvayer.github.io/) * [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT) +* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 42b848f..8a4aec9 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -156,11 +156,11 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cdef int cur_idx = 0 while i < n and j < m: if metric == 'sqeuclidean': - m_ij = (u[i] - v[j]) ** 2 + m_ij = (u[i] - v[j]) * (u[i] - v[j]) elif metric == 'cityblock' or metric == 'euclidean': - m_ij = abs(u[i] - v[j]) + m_ij = math.fabs(u[i] - v[j]) elif metric == 'minkowski': - m_ij = math.pow(abs(u[i] - v[j]), p) + m_ij = math.pow(math.fabs(u[i] - v[j]), p) else: m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), metric=metric)[0, 0] -- cgit v1.2.3