summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/__init__.py6
-rw-r--r--ot/emd/__init__.py3
-rw-r--r--ot/lp/EMD.h (renamed from ot/emd/EMD.h)0
-rw-r--r--ot/lp/EMD_wrap.cpp (renamed from ot/emd/EMD_wrap.cpp)0
-rw-r--r--ot/lp/__init__.py3
-rw-r--r--ot/lp/core.h (renamed from ot/emd/core.h)0
-rw-r--r--ot/lp/emd.cpp (renamed from ot/emd/emd.cpp)0
-rw-r--r--ot/lp/emd.pyx (renamed from ot/emd/emd.pyx)0
-rw-r--r--ot/lp/full_bipartitegraph.h (renamed from ot/emd/full_bipartitegraph.h)0
-rw-r--r--ot/lp/network_simplex_simple.h (renamed from ot/emd/network_simplex_simple.h)0
-rw-r--r--ot/optim.py2
-rwxr-xr-xsetup.py6
12 files changed, 11 insertions, 9 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 4c3c40e..87119e5 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -5,15 +5,17 @@ from . import utils
from . import datasets
from . import plot
from . import bregman
+from . import lp
from . import da
from . import optim
+
# OT functions
-from ot.emd import emd
+from ot.lp import emd
from ot.bregman import sinkhorn,barycenter
from ot.da import sinkhorn_lpl1_mm
# utils functions
from utils import dist,unif
-__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif','barycenter','sinkhorn_lpl1_mm','da','optim']
+__all__ = ["emd","sinkhorn","utils",'datasets','bregman','lp','plot','dist','unif','barycenter','sinkhorn_lpl1_mm','da','optim']
diff --git a/ot/emd/__init__.py b/ot/emd/__init__.py
deleted file mode 100644
index 4c38ffa..0000000
--- a/ot/emd/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-from emd import emd
diff --git a/ot/emd/EMD.h b/ot/lp/EMD.h
index 40d7192..40d7192 100644
--- a/ot/emd/EMD.h
+++ b/ot/lp/EMD.h
diff --git a/ot/emd/EMD_wrap.cpp b/ot/lp/EMD_wrap.cpp
index 52cd262..52cd262 100644
--- a/ot/emd/EMD_wrap.cpp
+++ b/ot/lp/EMD_wrap.cpp
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
new file mode 100644
index 0000000..65cee83
--- /dev/null
+++ b/ot/lp/__init__.py
@@ -0,0 +1,3 @@
+
+
+from . import emd
diff --git a/ot/emd/core.h b/ot/lp/core.h
index 04dddf7..04dddf7 100644
--- a/ot/emd/core.h
+++ b/ot/lp/core.h
diff --git a/ot/emd/emd.cpp b/ot/lp/emd.cpp
index 2343af6..2343af6 100644
--- a/ot/emd/emd.cpp
+++ b/ot/lp/emd.cpp
diff --git a/ot/emd/emd.pyx b/ot/lp/emd.pyx
index 753b195..753b195 100644
--- a/ot/emd/emd.pyx
+++ b/ot/lp/emd.pyx
diff --git a/ot/emd/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h
index 87a1bec..87a1bec 100644
--- a/ot/emd/full_bipartitegraph.h
+++ b/ot/lp/full_bipartitegraph.h
diff --git a/ot/emd/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 64856a0..64856a0 100644
--- a/ot/emd/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
diff --git a/ot/optim.py b/ot/optim.py
index 66bbfb5..e6373ce 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -8,7 +8,7 @@ Created on Wed Oct 26 15:08:19 2016
import numpy as np
import scipy as sp
from scipy.optimize.linesearch import scalar_search_armijo
-from emd import emd
+from lp import emd
# The corresponding scipy function does not work for matrices
def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
diff --git a/setup.py b/setup.py
index 969ec99..60b9a9f 100755
--- a/setup.py
+++ b/setup.py
@@ -30,15 +30,15 @@ setup(name='POT',
packages=find_packages(),
ext_modules = cythonize(Extension(
"ot.emd.emd", # the extension name
- sources=["ot/emd/emd.pyx", "ot/emd/EMD_wrap.cpp"], # the Cython source and
+ sources=["ot/lp/emd.pyx", "ot/lp/EMD_wrap.cpp"], # the Cython source and
# additional C++ source files
language="c++", # generate and compile C++ code,
- include_dirs=[numpy.get_include(),os.path.join(ROOT,'ot/emd')])),
+ include_dirs=[numpy.get_include(),os.path.join(ROOT,'ot/lp')])),
platforms=['linux','macosx','windows'],
license = 'MIT',
scripts=[],
data_files=[],
- requires=["numpy (>=1.11)"],
+ requires=["numpy (>=1.11)","scipy (>=0.17)"],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',