summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorarolet <antoine.rolet@gmail.com>2017-07-21 13:34:09 +0900
committerarolet <antoine.rolet@gmail.com>2017-07-21 13:34:09 +0900
commitc1980a414c879dd1bc1d8904fd43426326741385 (patch)
tree2513d6ed5ee49d351652e0f989d83cb8298e72fa
parentdb2a70b1f5146d6374af57f4bea66ab95b202e77 (diff)
Added and passed tests for dual variables
-rw-r--r--ot/lp/EMD_wrapper.cpp2
-rw-r--r--ot/lp/emd_wrap.pyx4
-rw-r--r--test/test_emd.py28
3 files changed, 22 insertions, 12 deletions
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index c6cbb04..0977e75 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -101,7 +101,7 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double flow = net.flow(a);
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
*(G+indI[i]*n2+indJ[j-n]) = flow;
- *(alpha + indI[i]) = net.potential(i);
+ *(alpha + indI[i]) = -net.potential(i);
*(beta + indJ[j-n]) = net.potential(j);
}
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 813596f..435a270 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -57,7 +57,7 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
- cdef float cost=0
+ cdef double cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2)
@@ -116,7 +116,7 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
- cdef float cost=0
+ cdef double cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1])
diff --git a/test/test_emd.py b/test/test_emd.py
index 4757cd1..3bf6fa2 100644
--- a/test/test_emd.py
+++ b/test/test_emd.py
@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
import numpy as np
-import pylab as pl
import ot
from ot.datasets import get_1D_gauss as gauss
@@ -16,8 +15,6 @@ m=6000 # nb bins
mean1 = 1000
mean2 = 1100
-tol = 1e-6
-
# bin positions
x=np.arange(n,dtype=np.float64)
y=np.arange(m,dtype=np.float64)
@@ -38,10 +35,11 @@ print('Computing {} EMD '.format(1))
# emd loss 1 proc
ot.tic()
-G = ot.emd(a,b,M)
+G, alpha, beta = ot.emd(a,b,M, dual_variables=True)
ot.toc('1 proc : {} s')
cost1 = (G * M).sum()
+cost_dual = np.vdot(a, alpha) + np.vdot(b, beta)
# emd loss 1 proc
ot.tic()
@@ -49,11 +47,23 @@ cost_emd2 = ot.emd2(a,b,M)
ot.toc('1 proc : {} s')
ot.tic()
-G = ot.emd(b, a, np.ascontiguousarray(M.T))
+G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
ot.toc('1 proc : {} s')
-cost2 = (G * M.T).sum()
+cost2 = (G2 * M.T).sum()
+
+M_reduced = M - alpha.reshape(-1,1) - beta.reshape(1, -1)
+
+# Check that both cost computations are equivalent
+np.testing.assert_almost_equal(cost1, cost_emd2)
+# Check that dual and primal cost are equal
+np.testing.assert_almost_equal(cost1, cost_dual)
+# Check symmetry
+np.testing.assert_almost_equal(cost1, cost2)
+# Check with closed-form solution for gaussians
+np.testing.assert_almost_equal(cost1, np.abs(mean1-mean2))
+
+[ind1, ind2] = np.nonzero(G)
-assert np.abs(cost1-cost_emd2)/np.abs(cost1) < tol
-assert np.abs(cost1-cost2)/np.abs(cost1) < tol
-assert np.abs(cost1-np.abs(mean1-mean2))/np.abs(cost1) < tol
+# Check that reduced cost is zero on transport arcs
+np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size)) \ No newline at end of file