summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-12-15 09:28:01 +0100
committerGitHub <noreply@github.com>2022-12-15 09:28:01 +0100
commit0411ea22a96f9c22af30156b45c16ef39ffb520d (patch)
tree7c131ad804d5b16a8c362c2fe296350a770400df /ot/backend.py
parent8490196dcc982c492b7565e1ec4de5f75f006acf (diff)
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py
index e4b48e1..337e040 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -854,6 +854,21 @@ class Backend():
"""
raise NotImplementedError()
+ def kl_div(self, p, q, eps=1e-16):
+ r"""
+ Computes the Kullback-Leibler divergence.
+
+ This function follows the api from :any:`scipy.stats.entropy`.
+
+ Parameter eps is used to avoid numerical errors and is added in the log.
+
+ .. math::
+ KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
+ """
+ raise NotImplementedError()
+
def isfinite(self, a):
r"""
Tests element-wise for finiteness (not infinity and not Not a Number).
@@ -1158,6 +1173,9 @@ class NumpyBackend(Backend):
def sqrtm(self, a):
return scipy.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return np.sum(p * np.log(p / q + eps))
+
def isfinite(self, a):
return np.isfinite(a)
@@ -1481,6 +1499,9 @@ class JaxBackend(Backend):
L, V = jnp.linalg.eigh(a)
return (V * jnp.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return jnp.sum(p * jnp.log(p / q + eps))
+
def isfinite(self, a):
return jnp.isfinite(a)
@@ -1901,6 +1922,9 @@ class TorchBackend(Backend):
L, V = torch.linalg.eigh(a)
return (V * torch.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return torch.sum(p * torch.log(p / q + eps))
+
def isfinite(self, a):
return torch.isfinite(a)
@@ -2248,6 +2272,9 @@ class CupyBackend(Backend): # pragma: no cover
L, V = cp.linalg.eigh(a)
return (V * self.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return cp.sum(p * cp.log(p / q + eps))
+
def isfinite(self, a):
return cp.isfinite(a)
@@ -2608,6 +2635,9 @@ class TensorflowBackend(Backend):
def sqrtm(self, a):
return tf.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return tnp.sum(p * tnp.log(p / q + eps))
+
def isfinite(self, a):
return tnp.isfinite(a)