summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-04-04 13:58:50 +0200
committerKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-04-04 13:58:50 +0200
commit69186a6f4259d32fecac370f59efe16e2e460d04 (patch)
tree7577b7e5c6faaf3ce8eebed8b56ce8af1e7614b1 /ot/bregman.py
parent780bdfee3c622698dc9b18a02fa06381314aa56d (diff)
fix test sinkhorn div
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 7acfcf1..dc43834 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1587,8 +1587,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
log['log_sinkhorn_b'] = log_b
return max(0, sinkhorn_div), log
+
else:
- sinkhorn_div = (empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
- 1 / 2 * empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
- 1 / 2 * empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs))
+ sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)