summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-09-03 17:26:30 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-09-03 17:26:30 +0200
commit7efea812ad0b1c7e3783397dbd8f3ad802fb7ac2 (patch)
tree3ae039a5f932bc406f5c8f5c229f7a31750f94d0
parentc7269d3fc72c679711699a9df7b5670b0dd176b0 (diff)
same for unbalanced
-rw-r--r--ot/unbalanced.py102
1 files changed, 51 insertions, 51 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 3f71d28..25e4cf5 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -120,23 +120,23 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
"""
if method.lower() == 'sinkhorn':
- return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
- numItermax=numItermax,
- stopThr=stopThr,
- verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
- return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
@@ -241,29 +241,29 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
if len(b.shape) < 2:
b = b[:, None]
if method.lower() == 'sinkhorn':
- return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
- numItermax=numItermax,
- stopThr=stopThr,
- verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
- return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
raise ValueError('Unknown method %s.' % method)
-def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
- stopThr=1e-6, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -300,7 +300,7 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshol on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -439,9 +439,9 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
return u[:, None] * K * v[None, :]
-def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
- stopThr=1e-6, verbose=False, log=False,
- **kwargs):
+def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False,
+ **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport
problem and return the loss
@@ -653,9 +653,9 @@ def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=100
return ot_matrix
-def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
- numItermax=1000, stopThr=1e-6,
- verbose=False, log=False):
+def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
The function solves the following optimization problem:
@@ -804,9 +804,9 @@ def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
return q
-def _barycenter_unbalanced(A, M, reg, reg_m, weights=None,
- numItermax=1000, stopThr=1e-6,
- verbose=False, log=False):
+def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
r"""Compute the entropic unbalanced wasserstein barycenter of A.
The function solves the following optimization problem with a
@@ -1001,22 +1001,22 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
"""
if method.lower() == 'sinkhorn':
- return _barycenter_unbalanced(A, M, reg, reg_m,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- return _barycenter_unbalanced_stabilized(A, M, reg, reg_m,
- numItermax=numItermax,
- stopThr=stopThr,
- verbose=verbose,
- log=log, **kwargs)
+ return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
- return _barycenter_unbalanced(A, M, reg, reg_m,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return barycenter_unbalanced(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)