path: root/ot/unbalanced.py
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--  ot/unbalanced.py  302
1 file changed, 153 insertions, 149 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 15e180b..503cc1e 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -8,9 +8,9 @@ Regularized Unbalanced OT solvers
from __future__ import division
import warnings
-import numpy as np
-from scipy.special import logsumexp
+from .backend import get_backend
+from .utils import list_to_array
# from .utils import unif, dist
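These two imports replace the direct NumPy/SciPy calls below with POT's backend abstraction, so the same solver code can run on NumPy, PyTorch, or other supported tensor types. A minimal sketch of that dispatch pattern, using only the `list_to_array`/`get_backend` helpers imported above (the `entropic_kernel` name is illustrative, not part of the module):

# Illustrative sketch of the backend dispatch pattern used throughout this patch.
import numpy as np
from ot.backend import get_backend
from ot.utils import list_to_array

def entropic_kernel(M, reg):
    M = list_to_array(M)       # promote a plain Python list to an array
    nx = get_backend(M)        # pick the backend that matches M's type
    return nx.exp(-M / reg)    # same expression as K in the solvers below

K = entropic_kernel([[0.0, 1.0], [1.0, 0.0]], reg=0.5)
print(type(K))                 # numpy.ndarray for NumPy/list input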
@@ -43,12 +43,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -70,12 +70,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
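A hypothetical single-histogram call matching the signature documented above; the arrays and parameter values are illustrative, not taken from the patch:

# Illustrative usage of sinkhorn_unbalanced (values made up for the example).
import numpy as np
import ot

a = np.array([0.5, 0.5])            # unnormalized histogram, dim_a = 2
b = np.array([0.4, 0.6])            # unnormalized histogram, dim_b = 2
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])          # loss matrix of shape (dim_a, dim_b)

gamma = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=0.1, reg_m=1.0)
print(gamma.shape)                  # (2, 2): transport plan, since n_hists == 1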
@@ -172,12 +172,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -198,7 +198,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Returns
-------
- ot_distance : (n_hists,) ndarray
+ ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
log : dict
log dictionary returned only if `log` is `True`
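For the multi-histogram case described above, stacking histograms column-wise in `b` yields one distance per column; a hypothetical sketch with illustrative shapes:

# Illustrative usage of sinkhorn_unbalanced2 with n_hists = 3 histograms.
import numpy as np
import ot

rng = np.random.RandomState(0)
a = np.array([0.5, 0.5])
b = rng.rand(2, 3)                  # (dim_b, n_hists)
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])

losses = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=0.1, reg_m=1.0)
print(losses.shape)                 # (3,): one OT distance per column of b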
@@ -239,9 +239,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
"""
- b = np.asarray(b, dtype=np.float64)
+ b = list_to_array(b)
if len(b.shape) < 2:
b = b[:, None]
+
if method.lower() == 'sinkhorn':
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
numItermax=numItermax,
@@ -291,12 +292,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`
If many, compute all the OT distances (a, b_i)
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -315,12 +316,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -354,17 +355,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
ot.optim.cg : General regularized OT
"""
-
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=M) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=M) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -377,17 +376,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, 1)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, 1), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
a = a.reshape(dim_a, 1)
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(M / (-reg))
fi = reg_m / (reg_m + reg)
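The updates that follow raise the usual Sinkhorn ratios to the power fi = reg_m / (reg_m + reg), which tends to 1 (balanced OT) as reg_m grows and to 0 (no marginal constraint) as reg_m vanishes. A pure-NumPy sketch of one such update on made-up inputs, mirroring the lines above and below:

# Pure-NumPy sketch of one unbalanced Sinkhorn iteration (illustrative inputs).
import numpy as np

reg, reg_m = 0.1, 1.0
fi = reg_m / (reg_m + reg)
a = np.array([0.5, 0.5])
b = np.array([0.4, 0.6])
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])

K = np.exp(-M / reg)
u = np.ones(2) / 2
v = np.ones(2) / 2

u = (a / K.dot(v)) ** fi            # row scaling, softened by the exponent fi
v = (b / K.T.dot(u)) ** fi          # column scaling, softened by the exponent fi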
@@ -397,14 +393,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
uprev = u
vprev = v
- Kv = K.dot(v)
+ Kv = nx.dot(K, v)
u = (a / Kv) ** fi
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
v = (b / Ktu) ** fi
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % i)
@@ -412,8 +408,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
v = vprev
break
- err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err_u = nx.max(nx.abs(u - uprev)) / max(
+ nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
+ )
+ err_v = nx.max(nx.abs(v - vprev)) / max(
+ nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.
+ )
err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
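The stopping test above is a relative change of the scalings, with the denominator clipped at 1 so near-zero scalings do not inflate the error. Restated in NumPy on hypothetical vectors:

# Relative-change criterion restated in NumPy (illustrative arrays).
import numpy as np

u = np.array([1.0, 2.0])
uprev = np.array([1.1, 1.9])
err_u = np.abs(u - uprev).max() / max(np.abs(u).max(), np.abs(uprev).max(), 1.0)
print(err_u)                        # 0.05: largest change relative to the largest scaling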
@@ -426,11 +426,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
break
if log:
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
@@ -475,12 +475,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -501,12 +501,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -538,17 +538,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
ot.optim.cg : General regularized OT
"""
-
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+ nx = get_backend(M, a, b)
dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=M) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=M) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -561,56 +559,52 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
a = a.reshape(dim_a, 1)
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
# print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
- alpha = np.zeros(dim_a)
- beta = np.zeros(dim_b)
+ alpha = nx.zeros(dim_a, type_as=M)
+ beta = nx.zeros(dim_b, type_as=M)
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
- Kv = K.dot(v)
- f_alpha = np.exp(- alpha / (reg + reg_m))
- f_beta = np.exp(- beta / (reg + reg_m))
+ Kv = nx.dot(K, v)
+ f_alpha = nx.exp(- alpha / (reg + reg_m))
+ f_beta = nx.exp(- beta / (reg + reg_m))
if n_hists:
f_alpha = f_alpha[:, None]
f_beta = f_beta[:, None]
u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
if n_hists:
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
+ alpha = alpha + reg * nx.log(nx.max(u, 1))
+ beta = beta + reg * nx.log(nx.max(v, 1))
else:
- alpha = alpha + reg * np.log(np.max(u))
- beta = beta + reg * np.log(np.max(v))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
-
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha = alpha + reg * nx.log(nx.max(u))
+ beta = beta + reg * nx.log(nx.max(v))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(v.shape, type_as=v)
+ Kv = nx.dot(K, v)
+
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
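The absorption branch above keeps the scalings bounded: when u or v exceeds tau, their logs are folded into the dual potentials alpha and beta, the kernel is rebuilt with those potentials, and v is reset to ones. A pure-NumPy sketch of the single-histogram branch with made-up values (not the patched function itself):

# Sketch of one log-domain absorption step (single-histogram branch, illustrative).
import numpy as np

reg, tau = 0.1, 1e5
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])
alpha = np.zeros(2)
beta = np.zeros(2)
u = np.array([2e5, 3.0])            # a scaling has grown past tau
v = np.array([1.5, 0.7])

if (u > tau).any() or (v > tau).any():
    alpha = alpha + reg * np.log(np.max(u))   # absorb the scalings ...
    beta = beta + reg * np.log(np.max(v))     # ... into the dual potentials
    K = np.exp((alpha[:, None] + beta[None, :] - M) / reg)
    v = np.ones_like(v)                       # restart the column scaling near 1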
@@ -620,8 +614,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
if (cpt % 10 == 0 and not absorbing) or cpt == 0:
# we can speed up the process by checking for the error only every
# 10th iteration
- err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
- 1.)
+ err = nx.max(nx.abs(u - uprev)) / max(
+ nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
+ )
if log:
log['err'].append(err)
if verbose:
@@ -636,25 +631,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
"Try a larger entropy `reg` or a lower mass `reg_m`." +
"Or a larger absorption threshold `tau`.")
if n_hists:
- logu = alpha[:, None] / reg + np.log(u)
- logv = beta[:, None] / reg + np.log(v)
+ logu = alpha[:, None] / reg + nx.log(u)
+ logv = beta[:, None] / reg + nx.log(v)
else:
- logu = alpha / reg + np.log(u)
- logv = beta / reg + np.log(v)
+ logu = alpha / reg + nx.log(u)
+ logv = beta / reg + nx.log(v)
if log:
log['logu'] = logu
log['logv'] = logv
if n_hists: # return only loss
- res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
- logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
- res = np.exp(res)
+ res = nx.logsumexp(
+ nx.log(M + 1e-100)[:, :, None]
+ + logu[:, None, :]
+ + logv[None, :, :]
+ - M[:, :, None] / reg,
+ axis=(0, 1)
+ )
+ res = nx.exp(res)
if log:
return res, log
else:
return res
else: # return OT matrix
- ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg)
if log:
return ot_matrix, log
else:
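With the potentials kept in log scale, the returned plan is gamma_ij = exp(logu_i + logv_j - M_ij / reg), which is u_i K_ij v_j written in a numerically safer form; the multi-histogram branch computes the analogous loss sum_ij M_ij gamma_ij in log space via logsumexp. A small NumPy identity check on illustrative data:

# exp(log u_i + log v_j - M_ij / reg) equals u_i * K_ij * v_j (illustrative data).
import numpy as np

rng = np.random.RandomState(0)
reg = 0.1
M = rng.rand(3, 4)
u = rng.rand(3)
v = rng.rand(4)

K = np.exp(-M / reg)
direct = u[:, None] * K * v[None, :]
logged = np.exp(np.log(u)[:, None] + np.log(v)[None, :] - M / reg)
assert np.allclose(direct, logged)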
@@ -683,9 +683,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
@@ -693,7 +693,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Marginal relaxation term > 0
tau : float
Stabilization threshold for log domain absorption.
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coordinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -708,7 +708,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
@@ -726,9 +726,12 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
"""
+ A, M = list_to_array(A, M)
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones(n_hists, type_as=A) / n_hists
else:
assert(len(weights) == A.shape[1])
@@ -737,47 +740,43 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
fi = reg_m / (reg_m + reg)
- u = np.ones((dim, n_hists)) / dim
- v = np.ones((dim, n_hists)) / dim
+ u = nx.ones((dim, n_hists), type_as=A) / dim
+ v = nx.ones((dim, n_hists), type_as=A) / dim
# print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
- alpha = np.zeros(dim)
- beta = np.zeros(dim)
- q = np.ones(dim) / dim
+ alpha = nx.zeros(dim, type_as=A)
+ beta = nx.zeros(dim, type_as=A)
+ q = nx.ones(dim, type_as=A) / dim
for i in range(numItermax):
- qprev = q.copy()
- Kv = K.dot(v)
- f_alpha = np.exp(- alpha / (reg + reg_m))
- f_beta = np.exp(- beta / (reg + reg_m))
+ qprev = nx.copy(q)
+ Kv = nx.dot(K, v)
+ f_alpha = nx.exp(- alpha / (reg + reg_m))
+ f_beta = nx.exp(- beta / (reg + reg_m))
f_alpha = f_alpha[:, None]
f_beta = f_beta[:, None]
u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
q = (Ktu ** (1 - fi)) * f_beta
- q = q.dot(weights) ** (1 / (1 - fi))
+ q = nx.dot(q, weights) ** (1 / (1 - fi))
Q = q[:, None]
v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha = alpha + reg * nx.log(nx.max(u, 1))
+ beta = beta + reg * nx.log(nx.max(v, 1))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(v.shape, type_as=v)
+ Kv = nx.dot(K, v)
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
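The barycenter update above is a weighted power mean of the per-histogram marginals K^T u_k: q = (sum_k w_k (K^T u_k)^(1 - fi))^(1 / (1 - fi)), approaching a weighted geometric mean as fi tends to 1. Restated in NumPy on made-up data, with the stabilization factor f_beta of the patched code omitted:

# Power-mean aggregation behind the barycenter update (illustrative data only).
import numpy as np

rng = np.random.RandomState(0)
reg, reg_m = 0.1, 1.0
fi = reg_m / (reg_m + reg)
Ktu = rng.rand(5, 3)                # (dim, n_hists) marginals K^T u
weights = np.ones(3) / 3

q = (Ktu ** (1 - fi)).dot(weights) ** (1 / (1 - fi))
print(q.shape)                      # (5,): current barycenter estimate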
@@ -786,8 +785,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
if (i % 10 == 0 and not absorbing) or i == 0:
# we can speed up the process by checking for the error only every
# 10th iteration
- err = abs(q - qprev).max() / max(abs(q).max(),
- abs(qprev).max(), 1.)
+ err = nx.max(nx.abs(q - qprev)) / max(
+ nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.
+ )
if log:
log['err'].append(err)
if verbose:
@@ -804,8 +804,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
"Or a larger absorption threshold `tau`.")
if log:
log['niter'] = i
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
return q, log
else:
return q
@@ -833,15 +833,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
reg_m: float
Marginal relaxation term > 0
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coordinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -856,7 +856,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
@@ -874,40 +874,43 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
"""
+ A, M = list_to_array(A, M)
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones(n_hists, type_as=A) / n_hists
else:
assert(len(weights) == A.shape[1])
if log:
log = {'err': []}
- K = np.exp(- M / reg)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
- v = np.ones((dim, n_hists))
- u = np.ones((dim, 1))
- q = np.ones(dim)
+ v = nx.ones((dim, n_hists), type_as=A)
+ u = nx.ones((dim, 1), type_as=A)
+ q = nx.ones(dim, type_as=A)
err = 1.
for i in range(numItermax):
- uprev = u.copy()
- vprev = v.copy()
- qprev = q.copy()
+ uprev = nx.copy(u)
+ vprev = nx.copy(v)
+ qprev = nx.copy(q)
- Kv = K.dot(v)
+ Kv = nx.dot(K, v)
u = (A / Kv) ** fi
- Ktu = K.T.dot(u)
- q = ((Ktu ** (1 - fi)).dot(weights))
+ Ktu = nx.dot(K.T, u)
+ q = nx.dot(Ktu ** (1 - fi), weights)
q = q ** (1 / (1 - fi))
Q = q[:, None]
v = (Q / Ktu) ** fi
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % i)
@@ -916,8 +919,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
q = qprev
break
# compute change in barycenter
- err = abs(q - qprev).max()
- err /= max(abs(q).max(), abs(qprev).max(), 1.)
+ err = nx.max(nx.abs(q - qprev)) / max(
+ nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0
+ )
if log:
log['err'].append(err)
# if barycenter did not change + at least 10 iterations - stop
@@ -932,8 +936,8 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
if log:
log['niter'] = i
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
return q, log
else:
return q
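A hypothetical call to the public barycenter wrapper whose docstring follows; shapes, values, and the cost construction are illustrative only:

# Illustrative barycenter computation over n_hists = 3 histograms.
import numpy as np
import ot

rng = np.random.RandomState(0)
dim, n_hists = 4, 3
A = rng.rand(dim, n_hists)
A /= A.sum(axis=0)                  # roughly normalized input histograms
x = np.arange(dim, dtype=float).reshape(-1, 1)
M = ot.dist(x, x)                   # (dim, dim) squared-Euclidean ground metric
M /= M.max()                        # rescale costs so exp(-M / reg) stays well behaved

bary = ot.unbalanced.barycenter_unbalanced(A, M, reg=0.1, reg_m=1.0)
print(bary.shape)                   # (dim,): unbalanced Wasserstein barycenter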
@@ -961,15 +965,15 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
reg_m: float
Marginal relaxation term > 0
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coordinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -984,7 +988,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters