summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/bregman.py26
-rw-r--r--test/test_bregman.py12
2 files changed, 33 insertions, 5 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 47554fb..7acfcf1 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1569,8 +1569,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
'''
+ if log:
+ sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_div = (2 * empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
- empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
- empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs))
- return max(0, sinkhorn_div)
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+
+ log = {}
+ log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
+ log['sinkhorn_loss_a'] = sinkhorn_loss_a
+ log['sinkhorn_loss_b'] = sinkhorn_loss_b
+ log['log_sinkhorn_ab'] = log_ab
+ log['log_sinkhorn_a'] = log_a
+ log['log_sinkhorn_b'] = log_b
+
+ return max(0, sinkhorn_div), log
+ else:
+ sinkhorn_div = (empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
+ 1 / 2 * empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
+ 1 / 2 * empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs))
+ return max(0, sinkhorn_div)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 0ebd546..68d3595 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -204,6 +204,9 @@ def test_empirical_sinkhorn():
G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
@@ -216,6 +219,10 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(
sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian
np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian
np.testing.assert_allclose(
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
@@ -237,8 +244,11 @@ def test_empirical_sinkhorn_divergence():
sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) -
ot.sinkhorn2(b, b, M_t, 1))
+ emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True)
+ sinkhorn_div_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
# check constratints
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
np.testing.assert_allclose(
- emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
+ emp_sinkhorn_div_log, sinkhorn_div_log, atol=1e-05) # cf conv emp sinkhorn