diff options
Diffstat (limited to 'ot/partial.py')
-rwxr-xr-x | ot/partial.py | 122 |
1 files changed, 83 insertions, 39 deletions
diff --git a/ot/partial.py b/ot/partial.py index 0a9e450..bf4119d 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -8,6 +8,8 @@ Partial OT solvers import numpy as np from .lp import emd +from .backend import get_backend +from .utils import list_to_array def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, @@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass """ - if np.sum(a) > 1 or np.sum(b) > 1: + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + + if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors raise ValueError("Problem infeasible. Check that a and b are in the " "simplex") if reg_m is None: - reg_m = np.max(M) + 1 - if reg_m < -np.max(M): - return np.zeros((len(a), len(b))) + reg_m = float(nx.max(M)) + 1 + if reg_m < -nx.max(M): + return nx.zeros((len(a), len(b)), type_as=M) + + a0, b0, M0 = a, b, M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) eps = 1e-20 M = np.asarray(M, dtype=np.float64) @@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, gamma = np.zeros((len(a), len(b))) gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies] + # convert back to backend + gamma = nx.from_numpy(gamma, type_as=M0) + if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['cost'] = np.sum(gamma * M) + log_emd['cost'] = nx.sum(gamma * M0) + log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0) + log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0) + if log: return gamma, log_emd else: @@ -250,32 +266,52 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): entropic regularization parameter """ + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + + dim_a, dim_b = M.shape + if len(a) == 0: + a = nx.ones(dim_a, type_as=a) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=b) / dim_b + if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") - elif m > np.min((np.sum(a), np.sum(b))): + elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError("Problem infeasible. Parameter m should lower or" " equal than min(|a|_1, |b|_1).") - b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) - a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) - M_extended = np.zeros((len(a_extended), len(b_extended))) - M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2 - M_extended[:len(a), :len(b)] = M + b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies + b_extended = nx.concatenate((b, b_extension)) + a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies + a_extended = nx.concatenate((a, a_extension)) + M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2 + M_extended = nx.concatenate( + (nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1), + nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)), + axis=0 + ) gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, **kwargs) + + gamma = gamma[:len(a), :len(b)] + if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)]) + log_emd['partial_w_dist'] = nx.sum(M * gamma) + log_emd['u'] = log_emd['u'][:len(a)] + log_emd['v'] = log_emd['v'][:len(b)] if log: - return gamma[:len(a), :len(b)], log_emd + return gamma, log_emd else: - return gamma[:len(a), :len(b)] + return gamma def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): @@ -360,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): NeurIPS. """ + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True, **kwargs) log_w['T'] = partial_gw if log: - return np.sum(partial_gw * M), log_w + return nx.sum(partial_gw * M), log_w else: - return np.sum(partial_gw * M) + return nx.sum(partial_gw * M) def gwgrad_partial(C1, C2, T): @@ -809,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, ot.partial.partial_wasserstein: exact Partial Wasserstein """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) dim_a, dim_b = M.shape - dx = np.ones(dim_a, dtype=np.float64) - dy = np.ones(dim_b, dtype=np.float64) + dx = nx.ones(dim_a, type_as=a) + dy = nx.ones(dim_b, type_as=b) if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=a) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=b) / dim_b if m is None: - m = np.min((np.sum(a), np.sum(b))) * 1.0 + m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 if m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") - if m > np.min((np.sum(a), np.sum(b))): + if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): raise ValueError("Problem infeasible. Parameter m should lower or" " equal than min(|a|_1, |b|_1).") log_e = {'err': []} - # Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - np.multiply(K, m / np.sum(K), out=K) + if type(a) == type(b) == type(M) == np.ndarray: + # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + np.multiply(K, m / np.sum(K), out=K) + else: + K = nx.exp(-M / reg) + K = K * m / nx.sum(K) err, cpt = 1, 0 - q1 = np.ones(K.shape) - q2 = np.ones(K.shape) - q3 = np.ones(K.shape) + q1 = nx.ones(K.shape, type_as=K) + q2 = nx.ones(K.shape, type_as=K) + q3 = nx.ones(K.shape, type_as=K) while (err > stopThr and cpt < numItermax): Kprev = K K = K * q1 - K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K) + K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K) q1 = q1 * Kprev / K1 K1prev = K1 K1 = K1 * q2 - K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy))) + K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy))) q2 = q2 * K1prev / K2 K2prev = K2 K2 = K2 * q3 - K = K2 * (m / np.sum(K2)) + K = K2 * (m / nx.sum(K2)) q3 = q3 * K2prev / K - if np.any(np.isnan(K)) or np.any(np.isinf(K)): + if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)): print('Warning: numerical errors at iteration', cpt) break if cpt % 10 == 0: - err = np.linalg.norm(Kprev - K) + err = nx.norm(Kprev - K) if log: log_e['err'].append(err) if verbose: @@ -872,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, print('{:5d}|{:8e}|'.format(cpt, err)) cpt = cpt + 1 - log_e['partial_w_dist'] = np.sum(M * K) + log_e['partial_w_dist'] = nx.sum(M * K) if log: return K, log_e else: |