summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py
index e4b48e1..337e040 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -854,6 +854,21 @@ class Backend():
"""
raise NotImplementedError()
+ def kl_div(self, p, q, eps=1e-16):
+ r"""
+ Computes the Kullback-Leibler divergence.
+
+ This function follows the api from :any:`scipy.stats.entropy`.
+
+ Parameter eps is used to avoid numerical errors and is added in the log.
+
+ .. math::
+ KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
+ """
+ raise NotImplementedError()
+
def isfinite(self, a):
r"""
Tests element-wise for finiteness (not infinity and not Not a Number).
@@ -1158,6 +1173,9 @@ class NumpyBackend(Backend):
def sqrtm(self, a):
return scipy.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return np.sum(p * np.log(p / q + eps))
+
def isfinite(self, a):
return np.isfinite(a)
@@ -1481,6 +1499,9 @@ class JaxBackend(Backend):
L, V = jnp.linalg.eigh(a)
return (V * jnp.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return jnp.sum(p * jnp.log(p / q + eps))
+
def isfinite(self, a):
return jnp.isfinite(a)
@@ -1901,6 +1922,9 @@ class TorchBackend(Backend):
L, V = torch.linalg.eigh(a)
return (V * torch.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return torch.sum(p * torch.log(p / q + eps))
+
def isfinite(self, a):
return torch.isfinite(a)
@@ -2248,6 +2272,9 @@ class CupyBackend(Backend): # pragma: no cover
L, V = cp.linalg.eigh(a)
return (V * self.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return cp.sum(p * cp.log(p / q + eps))
+
def isfinite(self, a):
return cp.isfinite(a)
@@ -2608,6 +2635,9 @@ class TensorflowBackend(Backend):
def sqrtm(self, a):
return tf.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return tnp.sum(p * tnp.log(p / q + eps))
+
def isfinite(self, a):
return tnp.isfinite(a)