summaryrefslogtreecommitdiff
path: root/ot/unbalanced.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-07-02 16:44:41 +0200
committerRémi Flamary <remi.flamary@gmail.com>2019-07-02 16:44:41 +0200
commit4053866fb2003d6a84353f6a7b209418608c25eb (patch)
tree387a454400c3c853a2b2d12fb51b725ce105a7c2 /ot/unbalanced.py
parentef00ce42616fe7adf747c23a5590a83b62171a36 (diff)
parent8b3927bb5e8935c3dbddf054f054dc0c036fbdfe (diff)
Merge branch 'master' into doc_modules
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--ot/unbalanced.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index bad12d6..50ec03c 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -6,6 +6,7 @@ Regularized Unbalanced OT
# Author: Hicham Janati <hicham.janati@inria.fr>
# License: MIT License
+from __future__ import division
import warnings
import numpy as np
# from .utils import unif, dist
@@ -13,7 +14,7 @@ import numpy as np
def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
- u"""
+ r"""
Solve the unbalanced entropic regularization optimal transport problem and return the loss
The function solves the following optimization problem:
@@ -75,7 +76,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
array([[0.51122823, 0.18807035],
- [0.18807035, 0.51122823]])
+ [0.18807035, 0.51122823]])
References
@@ -122,7 +123,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
numItermax=1000, stopThr=1e-9, verbose=False,
log=False, **kwargs):
- u"""
+ r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
The function solves the following optimization problem:
@@ -233,7 +234,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
- """
+ r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
The function solves the following optimization problem:
@@ -287,12 +288,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
--------
>>> import ot
- >>> a=[.5, .15]
+ >>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.],[1., 0.]]
- >>> ot.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
- array([[0.52761554, 0.22392482],
- [0.10286295, 0.32257641]])
+ >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
References
----------
@@ -401,7 +402,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False):
- """Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
+ r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
The function solves the following optimization problem: