summaryrefslogtreecommitdiff
path: root/ot/partial.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/partial.py')
-rwxr-xr-xot/partial.py122
1 files changed, 83 insertions, 39 deletions
diff --git a/ot/partial.py b/ot/partial.py
index 0a9e450..bf4119d 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -8,6 +8,8 @@ Partial OT solvers
import numpy as np
from .lp import emd
+from .backend import get_backend
+from .utils import list_to_array
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
@@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass
"""
- if np.sum(a) > 1 or np.sum(b) > 1:
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
+ if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors
raise ValueError("Problem infeasible. Check that a and b are in the "
"simplex")
if reg_m is None:
- reg_m = np.max(M) + 1
- if reg_m < -np.max(M):
- return np.zeros((len(a), len(b)))
+ reg_m = float(nx.max(M)) + 1
+ if reg_m < -nx.max(M):
+ return nx.zeros((len(a), len(b)), type_as=M)
+
+ a0, b0, M0 = a, b, M
+ # convert to numpy
+ a, b, M = nx.to_numpy(a, b, M)
eps = 1e-20
M = np.asarray(M, dtype=np.float64)
@@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
gamma = np.zeros((len(a), len(b)))
gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies]
+ # convert back to backend
+ gamma = nx.from_numpy(gamma, type_as=M0)
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['cost'] = np.sum(gamma * M)
+ log_emd['cost'] = nx.sum(gamma * M0)
+ log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0)
+ log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0)
+
if log:
return gamma, log_emd
else:
@@ -250,32 +266,52 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
entropic regularization parameter
"""
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
+ dim_a, dim_b = M.shape
+ if len(a) == 0:
+ a = nx.ones(dim_a, type_as=a) / dim_a
+ if len(b) == 0:
+ b = nx.ones(dim_b, type_as=b) / dim_b
+
if m is None:
return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
elif m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- elif m > np.min((np.sum(a), np.sum(b))):
+ elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
- b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
- a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
- M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
- M_extended[:len(a), :len(b)] = M
+ b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies
+ b_extended = nx.concatenate((b, b_extension))
+ a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies
+ a_extended = nx.concatenate((a, a_extension))
+ M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2
+ M_extended = nx.concatenate(
+ (nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1),
+ nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)),
+ axis=0
+ )
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
**kwargs)
+
+ gamma = gamma[:len(a), :len(b)]
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)])
+ log_emd['partial_w_dist'] = nx.sum(M * gamma)
+ log_emd['u'] = log_emd['u'][:len(a)]
+ log_emd['v'] = log_emd['v'][:len(b)]
if log:
- return gamma[:len(a), :len(b)], log_emd
+ return gamma, log_emd
else:
- return gamma[:len(a), :len(b)]
+ return gamma
def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
@@ -360,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
NeurIPS.
"""
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
**kwargs)
log_w['T'] = partial_gw
if log:
- return np.sum(partial_gw * M), log_w
+ return nx.sum(partial_gw * M), log_w
else:
- return np.sum(partial_gw * M)
+ return nx.sum(partial_gw * M)
def gwgrad_partial(C1, C2, T):
@@ -809,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
ot.partial.partial_wasserstein: exact Partial Wasserstein
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
dim_a, dim_b = M.shape
- dx = np.ones(dim_a, dtype=np.float64)
- dy = np.ones(dim_b, dtype=np.float64)
+ dx = nx.ones(dim_a, type_as=a)
+ dy = nx.ones(dim_b, type_as=b)
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=a) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=b) / dim_b
if m is None:
- m = np.min((np.sum(a), np.sum(b))) * 1.0
+ m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0
if m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- if m > np.min((np.sum(a), np.sum(b))):
+ if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
log_e = {'err': []}
- # Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
- np.multiply(K, m / np.sum(K), out=K)
+ if type(a) == type(b) == type(M) == np.ndarray:
+ # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+ np.multiply(K, m / np.sum(K), out=K)
+ else:
+ K = nx.exp(-M / reg)
+ K = K * m / nx.sum(K)
err, cpt = 1, 0
- q1 = np.ones(K.shape)
- q2 = np.ones(K.shape)
- q3 = np.ones(K.shape)
+ q1 = nx.ones(K.shape, type_as=K)
+ q2 = nx.ones(K.shape, type_as=K)
+ q3 = nx.ones(K.shape, type_as=K)
while (err > stopThr and cpt < numItermax):
Kprev = K
K = K * q1
- K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
+ K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K)
q1 = q1 * Kprev / K1
K1prev = K1
K1 = K1 * q2
- K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
+ K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy)))
q2 = q2 * K1prev / K2
K2prev = K2
K2 = K2 * q3
- K = K2 * (m / np.sum(K2))
+ K = K2 * (m / nx.sum(K2))
q3 = q3 * K2prev / K
- if np.any(np.isnan(K)) or np.any(np.isinf(K)):
+ if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)):
print('Warning: numerical errors at iteration', cpt)
break
if cpt % 10 == 0:
- err = np.linalg.norm(Kprev - K)
+ err = nx.norm(Kprev - K)
if log:
log_e['err'].append(err)
if verbose:
@@ -872,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
print('{:5d}|{:8e}|'.format(cpt, err))
cpt = cpt + 1
- log_e['partial_w_dist'] = np.sum(M * K)
+ log_e['partial_w_dist'] = nx.sum(M * K)
if log:
return K, log_e
else: