From 69186a6f4259d32fecac370f59efe16e2e460d04 Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Thu, 4 Apr 2019 13:58:50 +0200 Subject: fix test sinkhorn div --- ot/bregman.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 7acfcf1..dc43834 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1587,8 +1587,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli log['log_sinkhorn_b'] = log_b return max(0, sinkhorn_div), log + else: - sinkhorn_div = (empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - - 1 / 2 * empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - - 1 / 2 * empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)) + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) -- cgit v1.2.3