diff options
-rw-r--r-- | ot/da.py | 29 | ||||
-rwxr-xr-x | setup.py | 17 |
2 files changed, 31 insertions, 15 deletions
@@ -1605,13 +1605,14 @@ class EMDTransport(BaseTransport): on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 """ - def __init__(self, metric="sqeuclidean", norm=None, + def __init__(self, metric="sqeuclidean", norm=None, log=False, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10, max_iter=100000): self.metric = metric self.norm = norm + self.log = log self.limit_max = limit_max self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map @@ -1644,11 +1645,16 @@ class EMDTransport(BaseTransport): super(EMDTransport, self).fit(Xs, ys, Xt, yt) - # coupling estimation - self.coupling_ = emd( - a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter - ) + returned_ = emd( + a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter, + log=self.log) + # coupling estimation + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() return self @@ -1705,7 +1711,7 @@ class SinkhornLpl1Transport(BaseTransport): """ def __init__(self, reg_e=1., reg_cl=0.1, - max_iter=10, max_inner_iter=200, + max_iter=10, max_inner_iter=200, log=False, tol=10e-9, verbose=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, @@ -1716,6 +1722,7 @@ class SinkhornLpl1Transport(BaseTransport): self.max_iter = max_iter self.max_inner_iter = max_inner_iter self.tol = tol + self.log = log self.verbose = verbose self.metric = metric self.norm = norm @@ -1753,12 +1760,18 @@ class SinkhornLpl1Transport(BaseTransport): super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt) - self.coupling_ = sinkhorn_lpl1_mm( + returned_ = sinkhorn_lpl1_mm( a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_, reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, - verbose=self.verbose) + verbose=self.verbose, log=self.log) + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() return self @@ -18,21 +18,18 @@ __version__ = re.search( # The beautiful part is, I don't even need to check exceptions here. # If something messes up, let the build process fail noisy, BEFORE my release! +# thanks Pipy for handling markdown now ROOT = os.path.abspath(os.path.dirname(__file__)) - -# convert markdown readme to rst if pypandoc installed -try: - import pypandoc - README = pypandoc.convert('README.md', 'rst') -except (IOError, ImportError): - README = open(os.path.join(ROOT, 'README.md'), encoding="utf-8").read() +with open(os.path.join(ROOT, 'README.md'), encoding="utf-8") as f: + README = f.read() setup(name='POT', version=__version__, description='Python Optimal Transport Library', long_description=README, + long_description_content_type='text/markdown', author=u'Remi Flamary, Nicolas Courty', author_email='remi.flamary@gmail.com, ncourty@gmail.com', url='https://github.com/rflamary/POT', @@ -59,5 +56,11 @@ setup(name='POT', 'Operating System :: POSIX', 'Programming Language :: Python', 'Topic :: Utilities' + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', ] ) |