diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-27 12:34:42 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-27 12:34:42 +0200 |
commit | e083f90ad09a3bd42beffea1e996f3b4a9b3ff76 (patch) | |
tree | f329e51af871ef1f415a87d4f9820c50c03fc4fc | |
parent | 708aadb3396129c56cf128be04b7e87304b95070 (diff) |
rename emd module to lp
-rw-r--r-- | ot/__init__.py | 6 | ||||
-rw-r--r-- | ot/emd/__init__.py | 3 | ||||
-rw-r--r-- | ot/lp/EMD.h (renamed from ot/emd/EMD.h) | 0 | ||||
-rw-r--r-- | ot/lp/EMD_wrap.cpp (renamed from ot/emd/EMD_wrap.cpp) | 0 | ||||
-rw-r--r-- | ot/lp/__init__.py | 3 | ||||
-rw-r--r-- | ot/lp/core.h (renamed from ot/emd/core.h) | 0 | ||||
-rw-r--r-- | ot/lp/emd.cpp (renamed from ot/emd/emd.cpp) | 0 | ||||
-rw-r--r-- | ot/lp/emd.pyx (renamed from ot/emd/emd.pyx) | 0 | ||||
-rw-r--r-- | ot/lp/full_bipartitegraph.h (renamed from ot/emd/full_bipartitegraph.h) | 0 | ||||
-rw-r--r-- | ot/lp/network_simplex_simple.h (renamed from ot/emd/network_simplex_simple.h) | 0 | ||||
-rw-r--r-- | ot/optim.py | 2 | ||||
-rwxr-xr-x | setup.py | 6 |
12 files changed, 11 insertions, 9 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index 4c3c40e..87119e5 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -5,15 +5,17 @@ from . import utils from . import datasets from . import plot from . import bregman +from . import lp from . import da from . import optim + # OT functions -from ot.emd import emd +from ot.lp import emd from ot.bregman import sinkhorn,barycenter from ot.da import sinkhorn_lpl1_mm # utils functions from utils import dist,unif -__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif','barycenter','sinkhorn_lpl1_mm','da','optim'] +__all__ = ["emd","sinkhorn","utils",'datasets','bregman','lp','plot','dist','unif','barycenter','sinkhorn_lpl1_mm','da','optim'] diff --git a/ot/emd/__init__.py b/ot/emd/__init__.py deleted file mode 100644 index 4c38ffa..0000000 --- a/ot/emd/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ - - -from emd import emd diff --git a/ot/emd/EMD.h b/ot/lp/EMD.h index 40d7192..40d7192 100644 --- a/ot/emd/EMD.h +++ b/ot/lp/EMD.h diff --git a/ot/emd/EMD_wrap.cpp b/ot/lp/EMD_wrap.cpp index 52cd262..52cd262 100644 --- a/ot/emd/EMD_wrap.cpp +++ b/ot/lp/EMD_wrap.cpp diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py new file mode 100644 index 0000000..65cee83 --- /dev/null +++ b/ot/lp/__init__.py @@ -0,0 +1,3 @@ + + +from . import emd diff --git a/ot/emd/core.h b/ot/lp/core.h index 04dddf7..04dddf7 100644 --- a/ot/emd/core.h +++ b/ot/lp/core.h diff --git a/ot/emd/emd.cpp b/ot/lp/emd.cpp index 2343af6..2343af6 100644 --- a/ot/emd/emd.cpp +++ b/ot/lp/emd.cpp diff --git a/ot/emd/emd.pyx b/ot/lp/emd.pyx index 753b195..753b195 100644 --- a/ot/emd/emd.pyx +++ b/ot/lp/emd.pyx diff --git a/ot/emd/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h index 87a1bec..87a1bec 100644 --- a/ot/emd/full_bipartitegraph.h +++ b/ot/lp/full_bipartitegraph.h diff --git a/ot/emd/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 64856a0..64856a0 100644 --- a/ot/emd/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h diff --git a/ot/optim.py b/ot/optim.py index 66bbfb5..e6373ce 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -8,7 +8,7 @@ Created on Wed Oct 26 15:08:19 2016 import numpy as np import scipy as sp from scipy.optimize.linesearch import scalar_search_armijo -from emd import emd +from lp import emd # The corresponding scipy function does not work for matrices def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99): @@ -30,15 +30,15 @@ setup(name='POT', packages=find_packages(), ext_modules = cythonize(Extension( "ot.emd.emd", # the extension name - sources=["ot/emd/emd.pyx", "ot/emd/EMD_wrap.cpp"], # the Cython source and + sources=["ot/lp/emd.pyx", "ot/lp/EMD_wrap.cpp"], # the Cython source and # additional C++ source files language="c++", # generate and compile C++ code, - include_dirs=[numpy.get_include(),os.path.join(ROOT,'ot/emd')])), + include_dirs=[numpy.get_include(),os.path.join(ROOT,'ot/lp')])), platforms=['linux','macosx','windows'], license = 'MIT', scripts=[], data_files=[], - requires=["numpy (>=1.11)"], + requires=["numpy (>=1.11)","scipy (>=0.17)"], classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', |