diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-12-15 09:28:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-15 09:28:01 +0100 |
commit | 0411ea22a96f9c22af30156b45c16ef39ffb520d (patch) | |
tree | 7c131ad804d5b16a8c362c2fe296350a770400df /ot/backend.py | |
parent | 8490196dcc982c492b7565e1ec4de5f75f006acf (diff) |
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver
* use itertools for product of parameters
* add tests for result class
* add tests for result class
* add tests for result class last time?
* add sinkhorn
* make partial OT bckend compatible
* add TV as unbalanced flavor
* better tests
* make smoth backend compatible and add l2 tregularizatio to solve
* add reularizedd unbalanced
* add test for more complex attibutes
* add test for more complex attibutes
* add generic unbalaned solver and implement it for ot.solve
* add entropy to possible regularization
* star of documentation for ot.solv
* weird new pep8
* documenttaion for function ot.solve done
* pep8
* Update ot/solvers.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* update release file
* Apply suggestions from code review
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* add test NotImplemented
* pep8
* pep8gcmp pep8!
* compute kl in backend
* debug tensorflow kl backend
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py index e4b48e1..337e040 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -854,6 +854,21 @@ class Backend(): """ raise NotImplementedError() + def kl_div(self, p, q, eps=1e-16): + r""" + Computes the Kullback-Leibler divergence. + + This function follows the api from :any:`scipy.stats.entropy`. + + Parameter eps is used to avoid numerical errors and is added in the log. + + .. math:: + KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon) + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html + """ + raise NotImplementedError() + def isfinite(self, a): r""" Tests element-wise for finiteness (not infinity and not Not a Number). @@ -1158,6 +1173,9 @@ class NumpyBackend(Backend): def sqrtm(self, a): return scipy.linalg.sqrtm(a) + def kl_div(self, p, q, eps=1e-16): + return np.sum(p * np.log(p / q + eps)) + def isfinite(self, a): return np.isfinite(a) @@ -1481,6 +1499,9 @@ class JaxBackend(Backend): L, V = jnp.linalg.eigh(a) return (V * jnp.sqrt(L)[None, :]) @ V.T + def kl_div(self, p, q, eps=1e-16): + return jnp.sum(p * jnp.log(p / q + eps)) + def isfinite(self, a): return jnp.isfinite(a) @@ -1901,6 +1922,9 @@ class TorchBackend(Backend): L, V = torch.linalg.eigh(a) return (V * torch.sqrt(L)[None, :]) @ V.T + def kl_div(self, p, q, eps=1e-16): + return torch.sum(p * torch.log(p / q + eps)) + def isfinite(self, a): return torch.isfinite(a) @@ -2248,6 +2272,9 @@ class CupyBackend(Backend): # pragma: no cover L, V = cp.linalg.eigh(a) return (V * self.sqrt(L)[None, :]) @ V.T + def kl_div(self, p, q, eps=1e-16): + return cp.sum(p * cp.log(p / q + eps)) + def isfinite(self, a): return cp.isfinite(a) @@ -2608,6 +2635,9 @@ class TensorflowBackend(Backend): def sqrtm(self, a): return tf.linalg.sqrtm(a) + def kl_div(self, p, q, eps=1e-16): + return tnp.sum(p * tnp.log(p / q + eps)) + def isfinite(self, a): return tnp.isfinite(a) |