diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 47554fb..7acfcf1 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1569,8 +1569,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 ''' + if log: + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_div = (2 * empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - - empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - - empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)) - return max(0, sinkhorn_div) + sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + + log = {} + log['sinkhorn_loss_ab'] = sinkhorn_loss_ab + log['sinkhorn_loss_a'] = sinkhorn_loss_a + log['sinkhorn_loss_b'] = sinkhorn_loss_b + log['log_sinkhorn_ab'] = log_ab + log['log_sinkhorn_a'] = log_a + log['log_sinkhorn_b'] = log_b + + return max(0, sinkhorn_div), log + else: + sinkhorn_div = (empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - + 1 / 2 * empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - + 1 / 2 * empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)) + return max(0, sinkhorn_div) |