summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-21 08:29:50 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-21 08:29:50 +0100
commit6fdf5de8fa27fa16d6b8910fe96eb67b7761aa0e (patch)
tree55b8d6e463b131134b4aa9d4e3013bbb77811da6 /ot/bregman.py
parent287c659ad35f5036ba2687caf73009ef455c7239 (diff)
add linear mapping test + autopep8
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py30
1 files changed, 20 insertions, 10 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index d63c51d..07b8660 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -11,7 +11,8 @@ Bregman projections for regularized OT
import numpy as np
-def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
u"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -120,7 +121,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ver
return sink()
-def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
u"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -233,7 +235,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ve
return sink()
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -403,7 +406,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, l
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
+def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
"""
Solve the entropic regularization OT problem with log stabilization
@@ -526,11 +530,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1)) -
+ beta.reshape((1, nb))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
+ return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) /
+ reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
# print(np.min(K))
@@ -620,7 +626,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa
return get_Gamma(alpha, beta, u, v)
-def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
+def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100,
+ tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
@@ -739,7 +746,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1)) -
+ beta.reshape((1, nb))) / reg)
# print(np.min(K))
def get_reg(n): # exponential decreasing
@@ -811,7 +819,8 @@ def projC(gamma, q):
return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
-def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
+def barycenter(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
@@ -904,7 +913,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=F
return geometricBar(weights, UKv)
-def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False):
+def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
+ stopThr=1e-3, verbose=False, log=False):
"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance