diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2018-05-09 13:26:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-05-09 13:26:33 +0200 |
commit | 27032b6fa0f2f68af3fe4f90e5dcbb68f130a962 (patch) | |
tree | bef4b8a49cf410652e9c100274d010519acaf0f8 /ot/bregman.py | |
parent | 1ff35860db2d612748270299d7ce0037b8d40702 (diff) | |
parent | 0496e2b1b2c2f4ea2d7f313ccf58c612efaa70bf (diff) |
Merge pull request #42 from rflamary/linear_mapping
Linear mapping + tests
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 30 |
1 files changed, 20 insertions, 10 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index d63c51d..07b8660 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -11,7 +11,8 @@ Bregman projections for regularized OT import numpy as np -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): u""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -120,7 +121,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ver return sink() -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): u""" Solve the entropic regularization optimal transport problem and return the loss @@ -233,7 +235,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ve return sink() -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the entropic regularization optimal transport problem and return the OT matrix @@ -403,7 +406,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, l return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs): +def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, + warmstart=None, verbose=False, print_period=20, log=False, **kwargs): """ Solve the entropic regularization OT problem with log stabilization @@ -526,11 +530,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa def get_K(alpha, beta): """log space computation""" - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) + return np.exp(-(M - alpha.reshape((na, 1)) - + beta.reshape((1, nb))) / reg) def get_Gamma(alpha, beta, u, v): """log space gamma computation""" - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb)))) + return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / + reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb)))) # print(np.min(K)) @@ -620,7 +626,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa return get_Gamma(alpha, beta, u, v) -def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): +def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, + tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): """ Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. @@ -739,7 +746,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne def get_K(alpha, beta): """log space computation""" - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) + return np.exp(-(M - alpha.reshape((na, 1)) - + beta.reshape((1, nb))) / reg) # print(np.min(K)) def get_reg(n): # exponential decreasing @@ -811,7 +819,8 @@ def projC(gamma, q): return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10)) -def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): +def barycenter(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): """Compute the entropic regularized wasserstein barycenter of distributions A The function solves the following optimization problem: @@ -904,7 +913,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=F return geometricBar(weights, UKv) -def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): +def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, + stopThr=1e-3, verbose=False, log=False): """ Compute the unmixing of an observation with a given dictionary using Wasserstein distance |