summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrançois Rozet <francois.rozet@outlook.com>2023-05-04 08:10:01 +0200
committerGitHub <noreply@github.com>2023-05-04 08:10:01 +0200
commit83dc498b496087aea293df1445442d8728435211 (patch)
treeb7410868d7d4b25daaa3e478f71b966fe073f24c
parent2aeb591be6b19a93f187516495ed15f1a47be925 (diff)
Improve Bures-Wasserstein distance (#468)
* Improve Bures-Wasserstein distance * Revert changes and modify sqrtm * Fix typo * Add changes to RELEASES.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md1
-rw-r--r--ot/backend.py8
-rw-r--r--ot/gaussian.py4
3 files changed, 8 insertions, 5 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 3366e2a..586089b 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -11,6 +11,7 @@
- Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457)
- Major documentation cleanup (PR #462, #467)
- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466)
+- Faster Bures-Wasserstein distance with NumPy backend (PR #468)
## 0.9.0
diff --git a/ot/backend.py b/ot/backend.py
index a82c448..eecf9dd 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1235,7 +1235,8 @@ class NumpyBackend(Backend):
return scipy.linalg.inv(a)
def sqrtm(self, a):
- return scipy.linalg.sqrtm(a)
+ L, V = np.linalg.eigh(a)
+ return (V * np.sqrt(L)[None, :]) @ V.T
def kl_div(self, p, q, eps=1e-16):
return np.sum(p * np.log(p / q + eps))
@@ -2433,7 +2434,7 @@ class CupyBackend(Backend): # pragma: no cover
def sqrtm(self, a):
L, V = cp.linalg.eigh(a)
- return (V * self.sqrt(L)[None, :]) @ V.T
+ return (V * cp.sqrt(L)[None, :]) @ V.T
def kl_div(self, p, q, eps=1e-16):
return cp.sum(p * cp.log(p / q + eps))
@@ -2824,7 +2825,8 @@ class TensorflowBackend(Backend):
return tf.linalg.inv(a)
def sqrtm(self, a):
- return tf.linalg.sqrtm(a)
+ L, V = tf.linalg.eigh(a)
+ return (V * tf.sqrt(L)[None, :]) @ V.T
def kl_div(self, p, q, eps=1e-16):
return tnp.sum(p * tnp.log(p / q + eps))
diff --git a/ot/gaussian.py b/ot/gaussian.py
index 4ffb726..1a29556 100644
--- a/ot/gaussian.py
+++ b/ot/gaussian.py
@@ -202,7 +202,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
where :
.. math::
- \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
+ \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
Parameters
----------
@@ -264,7 +264,7 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
where :
.. math::
- \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
+ \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
Parameters
----------