summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py29
-rwxr-xr-xsetup.py17
2 files changed, 31 insertions, 15 deletions
diff --git a/ot/da.py b/ot/da.py
index b83d67e..48b418f 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1605,13 +1605,14 @@ class EMDTransport(BaseTransport):
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
"""
- def __init__(self, metric="sqeuclidean", norm=None,
+ def __init__(self, metric="sqeuclidean", norm=None, log=False,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=10,
max_iter=100000):
self.metric = metric
self.norm = norm
+ self.log = log
self.limit_max = limit_max
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
@@ -1644,11 +1645,16 @@ class EMDTransport(BaseTransport):
super(EMDTransport, self).fit(Xs, ys, Xt, yt)
- # coupling estimation
- self.coupling_ = emd(
- a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter
- )
+ returned_ = emd(
+ a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter,
+ log=self.log)
+ # coupling estimation
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
return self
@@ -1705,7 +1711,7 @@ class SinkhornLpl1Transport(BaseTransport):
"""
def __init__(self, reg_e=1., reg_cl=0.1,
- max_iter=10, max_inner_iter=200,
+ max_iter=10, max_inner_iter=200, log=False,
tol=10e-9, verbose=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
@@ -1716,6 +1722,7 @@ class SinkhornLpl1Transport(BaseTransport):
self.max_iter = max_iter
self.max_inner_iter = max_inner_iter
self.tol = tol
+ self.log = log
self.verbose = verbose
self.metric = metric
self.norm = norm
@@ -1753,12 +1760,18 @@ class SinkhornLpl1Transport(BaseTransport):
super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
- self.coupling_ = sinkhorn_lpl1_mm(
+ returned_ = sinkhorn_lpl1_mm(
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
- verbose=self.verbose)
+ verbose=self.verbose, log=self.log)
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
return self
diff --git a/setup.py b/setup.py
index a32aa31..3066848 100755
--- a/setup.py
+++ b/setup.py
@@ -18,21 +18,18 @@ __version__ = re.search(
# The beautiful part is, I don't even need to check exceptions here.
# If something messes up, let the build process fail noisy, BEFORE my release!
+# thanks Pipy for handling markdown now
ROOT = os.path.abspath(os.path.dirname(__file__))
-
-# convert markdown readme to rst if pypandoc installed
-try:
- import pypandoc
- README = pypandoc.convert('README.md', 'rst')
-except (IOError, ImportError):
- README = open(os.path.join(ROOT, 'README.md'), encoding="utf-8").read()
+with open(os.path.join(ROOT, 'README.md'), encoding="utf-8") as f:
+ README = f.read()
setup(name='POT',
version=__version__,
description='Python Optimal Transport Library',
long_description=README,
+ long_description_content_type='text/markdown',
author=u'Remi Flamary, Nicolas Courty',
author_email='remi.flamary@gmail.com, ncourty@gmail.com',
url='https://github.com/rflamary/POT',
@@ -59,5 +56,11 @@ setup(name='POT',
'Operating System :: POSIX',
'Programming Language :: Python',
'Topic :: Utilities'
+ 'Programming Language :: Python :: 2',
+ 'Programming Language :: Python :: 2.7',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.4',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
]
)