diff options
author | Hicham Janati <hicham.janati@inria.fr> | 2019-09-03 17:26:30 +0200 |
---|---|---|
committer | Hicham Janati <hicham.janati@inria.fr> | 2019-09-03 17:26:30 +0200 |
commit | 7efea812ad0b1c7e3783397dbd8f3ad802fb7ac2 (patch) | |
tree | 3ae039a5f932bc406f5c8f5c229f7a31750f94d0 /ot/unbalanced.py | |
parent | c7269d3fc72c679711699a9df7b5670b0dd176b0 (diff) |
same for unbalanced
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r-- | ot/unbalanced.py | 102 |
1 files changed, 51 insertions, 51 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 3f71d28..25e4cf5 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -120,23 +120,23 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, """ if method.lower() == 'sinkhorn': - return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -241,29 +241,29 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', if len(b.shape) < 2: b = b[:, None] if method.lower() == 'sinkhorn': - return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError('Unknown method %s.' % method) -def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -300,7 +300,7 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -439,9 +439,9 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, return u[:, None] * K * v[None, :] -def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, - **kwargs): +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, + **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -653,9 +653,9 @@ def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=100 return ot_matrix -def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False): +def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization. The function solves the following optimization problem: @@ -804,9 +804,9 @@ def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, return q -def _barycenter_unbalanced(A, M, reg, reg_m, weights=None, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False): +def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False): r"""Compute the entropic unbalanced wasserstein barycenter of A. The function solves the following optimization problem with a @@ -1001,22 +1001,22 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, """ if method.lower() == 'sinkhorn': - return _barycenter_unbalanced(A, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return _barycenter_unbalanced_stabilized(A, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) + return barycenter_unbalanced_stabilized(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return _barycenter_unbalanced(A, M, reg, reg_m, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return barycenter_unbalanced(A, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) |