diff options
author | Hicham Janati <hicham.janati@inria.fr> | 2019-06-12 17:06:32 +0200 |
---|---|---|
committer | Hicham Janati <hicham.janati@inria.fr> | 2019-06-12 17:06:32 +0200 |
commit | 11381a7ecc79ef719ee9107167c3adc22b5a3f59 (patch) | |
tree | fd2b3b7c4ae59bc4050e6b69579f940e2a5a5f18 /ot | |
parent | 28b549ef3ef93c01462cd811d6e55c36ae5a76a2 (diff) |
integrate comments of jmassich
Diffstat (limited to 'ot')
-rw-r--r-- | ot/unbalanced.py | 54 |
1 files changed, 16 insertions, 38 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 8bd02eb..f4208b5 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -6,6 +6,7 @@ Regularized Unbalanced OT # Author: Hicham Janati <hicham.janati@inria.fr> # License: MIT License +import warnings import numpy as np # from .utils import unif, dist @@ -29,7 +30,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -85,15 +86,14 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ @@ -101,17 +101,8 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_stabilized': - # def sink(): - # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_epsilon_scaling': - # def sink(): - # return sinkhorn_epsilon_scaling( - # a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: - print('Warning : unknown method. Falling back to classic Sinkhorn Knopp') + warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp') def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, @@ -139,7 +130,7 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -196,18 +187,13 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 - - + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.greenkhorn : Greenkhorn [21] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ @@ -215,17 +201,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_stabilized': - # def sink(): - # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_epsilon_scaling': - # def sink(): - # return sinkhorn_epsilon_scaling( - # a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: - print('Warning : unknown method using classic Sinkhorn Knopp') + warnings.warn('Unknown method using classic Sinkhorn Knopp') def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs) @@ -256,7 +233,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -306,6 +283,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- @@ -368,7 +346,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration', cpt) u = uprev v = vprev break |