summaryrefslogtreecommitdiff
path: root/ot/unbalanced.py
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-06-12 17:06:32 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-06-12 17:06:32 +0200
commit11381a7ecc79ef719ee9107167c3adc22b5a3f59 (patch)
treefd2b3b7c4ae59bc4050e6b69579f940e2a5a5f18 /ot/unbalanced.py
parent28b549ef3ef93c01462cd811d6e55c36ae5a76a2 (diff)
integrate comments of jmassich
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--ot/unbalanced.py54
1 files changed, 16 insertions, 38 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 8bd02eb..f4208b5 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -6,6 +6,7 @@ Regularized Unbalanced OT
# Author: Hicham Janati <hicham.janati@inria.fr>
# License: MIT License
+import warnings
import numpy as np
# from .utils import unif, dist
@@ -29,7 +30,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- a and b are source and target weights
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
@@ -85,15 +86,14 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
See Also
--------
- ot.lp.emd : Unregularized OT
- ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
+ ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
+    ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
@@ -101,17 +101,8 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
def sink():
return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- # elif method.lower() == 'sinkhorn_stabilized':
- # def sink():
- # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- # stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- # elif method.lower() == 'sinkhorn_epsilon_scaling':
- # def sink():
- # return sinkhorn_epsilon_scaling(
- # a, b, M, reg, numItermax=numItermax,
- # stopThr=stopThr, verbose=verbose, log=log, **kwargs)
else:
- print('Warning : unknown method. Falling back to classic Sinkhorn Knopp')
+ warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp')
def sink():
return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
@@ -139,7 +130,7 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- a and b are source and target weights
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
@@ -196,18 +187,13 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
- [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
-
-
+ .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
See Also
--------
- ot.lp.emd : Unregularized OT
- ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.greenkhorn : Greenkhorn [21]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
+ ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
+    ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
@@ -215,17 +201,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
def sink():
return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- # elif method.lower() == 'sinkhorn_stabilized':
- # def sink():
- # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- # stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- # elif method.lower() == 'sinkhorn_epsilon_scaling':
- # def sink():
- # return sinkhorn_epsilon_scaling(
- # a, b, M, reg, numItermax=numItermax,
- # stopThr=stopThr, verbose=verbose, log=log, **kwargs)
else:
- print('Warning : unknown method using classic Sinkhorn Knopp')
+        warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp')
def sink():
return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs)
@@ -256,7 +233,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
- a and b are source and target weights
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
@@ -306,6 +283,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
See Also
--------
@@ -368,7 +346,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+            warnings.warn('Numerical errors at iteration %d' % cpt)
u = uprev
v = vprev
break