From 726e84e1e9f2832ea5ad156f62a5e3636c1fd3d3 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Fri, 6 May 2022 13:34:18 +0200 Subject: [MRG] Torch random generator not working for Cuda tensor (#373) * Solve bug * Update release file --- ot/backend.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) (limited to 'ot/backend.py') diff --git a/ot/backend.py b/ot/backend.py index 361ffba..e4b48e1 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1507,15 +1507,19 @@ class TorchBackend(Backend): def __init__(self): - self.rng_ = torch.Generator() + self.rng_ = torch.Generator("cpu") self.rng_.seed() self.__type_list__ = [torch.tensor(1, dtype=torch.float32), torch.tensor(1, dtype=torch.float64)] if torch.cuda.is_available(): + self.rng_cuda_ = torch.Generator("cuda") + self.rng_cuda_.seed() self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + else: + self.rng_cuda_ = torch.Generator("cpu") from torch.autograd import Function @@ -1761,20 +1765,26 @@ class TorchBackend(Backend): def seed(self, seed=None): if isinstance(seed, int): self.rng_.manual_seed(seed) + self.rng_cuda_.manual_seed(seed) elif isinstance(seed, torch.Generator): - self.rng_ = seed + if self.device_type(seed) == "GPU": + self.rng_cuda_ = seed + else: + self.rng_ = seed else: raise ValueError("Non compatible seed : {}".format(seed)) def rand(self, *size, type_as=None): if type_as is not None: - return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device) else: return torch.rand(size=size, generator=self.rng_) def randn(self, *size, type_as=None): if type_as is not None: - return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device) else: return torch.randn(size=size, generator=self.rng_) -- cgit v1.2.3 From 0411ea22a96f9c22af30156b45c16ef39ffb520d Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 15 Dec 2022 09:28:01 +0100 Subject: [MRG] New API for OT solver (with pre-computed ground cost matrix) (#388) * new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort --- RELEASES.md | 3 + ot/__init__.py | 7 +- ot/backend.py | 30 +++++ ot/partial.py | 47 +++++-- ot/smooth.py | 11 +- ot/solvers.py | 347 ++++++++++++++++++++++++++++++++++++++++++++++++ ot/unbalanced.py | 189 +++++++++++++++++++++++++- ot/utils.py | 202 +++++++++++++++++++++++++++- test/test_backend.py | 6 + test/test_partial.py | 2 + test/test_solvers.py | 133 +++++++++++++++++++ test/test_unbalanced.py | 23 ++++ test/test_utils.py | 29 ++++ 13 files changed, 1011 insertions(+), 18 deletions(-) create mode 100644 ot/solvers.py create mode 100644 test/test_solvers.py (limited to 'ot/backend.py') diff --git a/RELEASES.md b/RELEASES.md index 3bd84c1..9cfdd35 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,9 @@ - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) +- New API for OT solver using function `ot.solve` (PR #388) +- Backend version of `ot.partial` and `ot.smooth` (PR #388) + #### Closed issues diff --git a/ot/__init__.py b/ot/__init__.py index 15d8351..51eb726 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -34,6 +34,7 @@ from . import backend from . import regpath from . import weak from . import factored +from . import solvers # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -46,7 +47,7 @@ from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport - +from .solvers import solve # utils functions from .utils import dist, unif, tic, toc, toq @@ -61,5 +62,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', - 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] + 'factored_optimal_transport', 'solve', + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers'] diff --git a/ot/backend.py b/ot/backend.py index e4b48e1..337e040 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -854,6 +854,21 @@ class Backend(): """ raise NotImplementedError() + def kl_div(self, p, q, eps=1e-16): + r""" + Computes the Kullback-Leibler divergence. + + This function follows the api from :any:`scipy.stats.entropy`. + + Parameter eps is used to avoid numerical errors and is added in the log. + + .. math:: + KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon) + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html + """ + raise NotImplementedError() + def isfinite(self, a): r""" Tests element-wise for finiteness (not infinity and not Not a Number). @@ -1158,6 +1173,9 @@ class NumpyBackend(Backend): def sqrtm(self, a): return scipy.linalg.sqrtm(a) + def kl_div(self, p, q, eps=1e-16): + return np.sum(p * np.log(p / q + eps)) + def isfinite(self, a): return np.isfinite(a) @@ -1481,6 +1499,9 @@ class JaxBackend(Backend): L, V = jnp.linalg.eigh(a) return (V * jnp.sqrt(L)[None, :]) @ V.T + def kl_div(self, p, q, eps=1e-16): + return jnp.sum(p * jnp.log(p / q + eps)) + def isfinite(self, a): return jnp.isfinite(a) @@ -1901,6 +1922,9 @@ class TorchBackend(Backend): L, V = torch.linalg.eigh(a) return (V * torch.sqrt(L)[None, :]) @ V.T + def kl_div(self, p, q, eps=1e-16): + return torch.sum(p * torch.log(p / q + eps)) + def isfinite(self, a): return torch.isfinite(a) @@ -2248,6 +2272,9 @@ class CupyBackend(Backend): # pragma: no cover L, V = cp.linalg.eigh(a) return (V * self.sqrt(L)[None, :]) @ V.T + def kl_div(self, p, q, eps=1e-16): + return cp.sum(p * cp.log(p / q + eps)) + def isfinite(self, a): return cp.isfinite(a) @@ -2608,6 +2635,9 @@ class TensorflowBackend(Backend): def sqrtm(self, a): return tf.linalg.sqrtm(a) + def kl_div(self, p, q, eps=1e-16): + return tnp.sum(p * tnp.log(p / q + eps)) + def isfinite(self, a): return tnp.isfinite(a) diff --git a/ot/partial.py b/ot/partial.py index 0a9e450..eae91c4 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -8,6 +8,8 @@ Partial OT solvers import numpy as np from .lp import emd +from .backend import get_backend +from .utils import list_to_array def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, @@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass """ - if np.sum(a) > 1 or np.sum(b) > 1: + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + + if nx.sum(a) > 1 or nx.sum(b) > 1: raise ValueError("Problem infeasible. Check that a and b are in the " "simplex") if reg_m is None: - reg_m = np.max(M) + 1 - if reg_m < -np.max(M): - return np.zeros((len(a), len(b))) + reg_m = float(nx.max(M)) + 1 + if reg_m < -nx.max(M): + return nx.zeros((len(a), len(b)), type_as=M) + + a0, b0, M0 = a, b, M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) eps = 1e-20 M = np.asarray(M, dtype=np.float64) @@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, gamma = np.zeros((len(a), len(b))) gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies] + # convert back to backend + gamma = nx.from_numpy(gamma, type_as=M0) + if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['cost'] = np.sum(gamma * M) + log_emd['cost'] = nx.sum(gamma * M0) + log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0) + log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0) + if log: return gamma, log_emd else: @@ -250,15 +266,23 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): entropic regularization parameter """ + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") - elif m > np.min((np.sum(a), np.sum(b))): + elif m > nx.min((nx.sum(a), nx.sum(b))): raise ValueError("Problem infeasible. Parameter m should lower or" " equal than min(|a|_1, |b|_1).") + a0, b0, M0 = a, b, M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) + b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) M_extended = np.zeros((len(a_extended), len(b_extended))) @@ -267,15 +291,20 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, **kwargs) + + gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M) + if log_emd['warning'] is not None: raise ValueError("Error in the EMD resolution: try to increase the" " number of dummy points") - log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)]) + log_emd['partial_w_dist'] = nx.sum(M0 * gamma) + log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0) + log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0) if log: - return gamma[:len(a), :len(b)], log_emd + return gamma, log_emd else: - return gamma[:len(a), :len(b)] + return gamma def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): diff --git a/ot/smooth.py b/ot/smooth.py index 6855005..8e0ef38 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -44,6 +44,7 @@ Original code from https://github.com/mblondel/smooth-ot/ import numpy as np from scipy.optimize import minimize +from .backend import get_backend def projection_simplex(V, z=1, axis=None): @@ -511,6 +512,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, """ + nx = get_backend(a, b, M) + if reg_type.lower() in ['l2', 'squaredl2']: regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: @@ -518,15 +521,19 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, else: raise NotImplementedError('Unknown regularization') + a0, b0, M0 = a, b, M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) + # solve dual alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr, verbose=verbose) # reconstruct transport matrix - G = get_plan_from_dual(alpha, beta, M, regul) + G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), type_as=M0) if log: - log = {'alpha': alpha, 'beta': beta, 'res': res} + log = {'alpha': nx.from_numpy(alpha, type_as=a0), 'beta': nx.from_numpy(beta, type_as=b0), 'res': res} return G, log else: return G diff --git a/ot/solvers.py b/ot/solvers.py new file mode 100644 index 0000000..0294d71 --- /dev/null +++ b/ot/solvers.py @@ -0,0 +1,347 @@ +# -*- coding: utf-8 -*- +""" +General OT solvers with unified API +""" + +# Author: Remi Flamary +# +# License: MIT License + +from .utils import OTResult +from .lp import emd2 +from .backend import get_backend +from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced +from .bregman import sinkhorn_log +from .partial import partial_wasserstein_lagrange +from .smooth import smooth_ot_dual + + +def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + r"""Solve the discrete optimal transport problem and return :any:`OTResult` object + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with :any:`reg` (:math:`\lambda_r`) and :any:`reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with :any:`unbalanced` (:math:`\lambda_u`) and + :any:`unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + M : array_like, shape (dim_a, dim_b) + Loss matrix + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", 'entropy', by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization unction :math:`U` either "KL", "L2", 'TV', by default 'KL' + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + Notes + ----- + + The following methods are available for solving the OT problems: + + - **Classical exact OT problem** (default parameters): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve(M, a, b) + + - **Entropic regularized OT** (when ``reg!=None``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` regularization (``reg_type="KL"``) + res = ot.solve(M, a, b, reg=1.0) + # or for original Sinkhorn paper formulation [2] + res = ot.solve(M, a, b, reg=1.0, reg_type='entropy') + + - **Quadratic regularized OT** (when ``reg!=None`` and ``reg_type="L2"``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve(M,a,b,reg=1.0,reg_type='L2') + + - **Unbalanced OT** (when ``unbalanced!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` + res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0) + # quadratic unbalanced OT + res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2') + # TV = partial OT + res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='TV') + + + - **Regularized unbalanced regularized OT** (when ``unbalanced!=None`` and ``reg!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` for both + res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0) + # quadratic unbalanced OT with KL regularization + res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2') + # both quadratic + res = ot.solve(M,a,b,reg=1.0, reg_type='L2',unbalanced=1.0,unbalanced_type='L2') + + + .. _references-solve: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + + .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse + Optimal Transport. Proceedings of the Twenty-First International + Conference on Artificial Intelligence and Statistics (AISTATS). + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + + """ + + # detect backend + arr = [M] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # create uniform weights if not given + if a is None: + a = nx.ones(M.shape[0], type_as=M) / M.shape[0] + if b is None: + b = nx.ones(M.shape[1], type_as=M) / M.shape[1] + + # default values for solutions + potentials = None + value = None + value_linear = None + plan = None + status = None + + if reg is None or reg == 0: # exact OT + + if unbalanced is None: # Exact balanced OT + + # default values for EMD solver + if max_iter is None: + max_iter = 1000000 + + value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads) + + value = value_linear + potentials = (log['u'], log['v']) + plan = log['G'] + status = log["warning"] if log["warning"] is not None else 'Converged' + + elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT + + # default values for exact unbalanced OT + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-12 + + plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced, + div=unbalanced_type.lower(), numItermax=max_iter, + stopThr=tol, log=True, + verbose=verbose, G0=plan_init) + + value_linear = log['cost'] + + if unbalanced_type.lower() == 'kl': + value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) + else: + err_a = nx.sum(plan, 1) - a + err_b = nx.sum(plan, 0) - b + value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2) + + elif unbalanced_type.lower() == 'tv': + + if max_iter is None: + max_iter = 1000000 + + plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2, log=True, numItermax=max_iter) + + value_linear = nx.sum(M * plan) + err_a = nx.sum(plan, 1) - a + err_b = nx.sum(plan, 0) - b + value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) + + nx.sum(nx.abs(err_b)))) + + else: + raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) + + else: # regularized OT + + if unbalanced is None: # Balanced regularized OT + + if reg_type.lower() in ['entropy', 'kl']: + + # default values for sinkhorn + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, + stopThr=tol, log=True, + verbose=verbose) + + value_linear = nx.sum(M * plan) + + if reg_type.lower() == 'entropy': + value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) + else: + value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + + potentials = (log['log_u'], log['log_v']) + + elif reg_type.lower() == 'l2': + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter, stopThr=tol, log=True, verbose=verbose) + + value_linear = nx.sum(M * plan) + value = value_linear + reg * nx.sum(plan**2) + potentials = (log['alpha'], log['beta']) + + else: + raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) + + else: # unbalanced AND regularized OT + + if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) + + value_linear = nx.sum(M * plan) + + value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) + + potentials = (log['logu'], log['logv']) + + elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']: + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-12 + + plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) + + value_linear = nx.sum(M * plan) + + value = log['loss'] + + else: + raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) + + res = OTResult(potentials=potentials, value=value, + value_linear=value_linear, plan=plan, status=status, backend=nx) + + return res diff --git a/ot/unbalanced.py b/ot/unbalanced.py index dd9a36e..a71a0dd 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -10,6 +10,9 @@ Regularized Unbalanced OT solvers from __future__ import division import warnings +import numpy as np +from scipy.optimize import minimize, Bounds + from .backend import get_backend from .utils import list_to_array # from .utils import unif, dist @@ -269,7 +272,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" - Solve the entropic regularization unbalanced optimal transport problem and return the loss + Solve the entropic regularization unbalanced optimal transport problem and + return the OT plan The function solves the following optimization problem: @@ -734,7 +738,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists else: - assert len(weights) == A.shape[1] + assert (len(weights) == A.shape[1]) if log: log = {'err': []} @@ -882,7 +886,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists else: - assert len(weights) == A.shape[1] + assert (len(weights) == A.shape[1]) if log: log = {'err': []} @@ -1252,3 +1256,182 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, return log_mm['cost'], log_mm else: return log_mm['cost'] + + +def _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl'): + """ + return the loss function (scipy.optimize compatible) for regularized + unbalanced OT + """ + + m, n = M.shape + + def kl(p, q): + return np.sum(p * np.log(p / q + 1e-16)) + + def reg_l2(G): + return np.sum((G - a[:, None] * b[None, :])**2) / 2 + + def grad_l2(G): + return G - a[:, None] * b[None, :] + + def reg_kl(G): + return kl(G, a[:, None] * b[None, :]) + + def grad_kl(G): + return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + 1 + + def reg_entropy(G): + return kl(G, 1) + + def grad_entropy(G): + return np.log(G + 1e-16) + 1 + + if reg_div == 'kl': + reg_fun = reg_kl + grad_reg_fun = grad_kl + elif reg_div == 'entropy': + reg_fun = reg_entropy + grad_reg_fun = grad_entropy + else: + reg_fun = reg_l2 + grad_reg_fun = grad_l2 + + def marg_l2(G): + return 0.5 * np.sum((G.sum(1) - a)**2) + 0.5 * np.sum((G.sum(0) - b)**2) + + def grad_marg_l2(G): + return np.outer((G.sum(1) - a), np.ones(n)) + np.outer(np.ones(m), (G.sum(0) - b)) + + def marg_kl(G): + return kl(G.sum(1), a) + kl(G.sum(0), b) + + def grad_marg_kl(G): + return np.outer(np.log(G.sum(1) / a + 1e-16) + 1, np.ones(n)) + np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16) + 1) + + if regm_div == 'kl': + regm_fun = marg_kl + grad_regm_fun = grad_marg_kl + else: + regm_fun = marg_l2 + grad_regm_fun = grad_marg_l2 + + def _func(G): + G = G.reshape((m, n)) + + # compute loss + val = np.sum(G * M) + reg * reg_fun(G) + reg_m * regm_fun(G) + + # compute gradient + grad = M + reg * grad_reg_fun(G) + reg_m * grad_regm_fun(G) + + return val, grad.ravel() + + return _func + + +def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, + stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + + \mathrm{reg} \mathrm{div}(\gamma,\mathbf{a}\mathbf{b}^T) + \mathrm{reg_m} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + unbalanced distributions + - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence + + The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg: float + regularization term (>=0) + reg_m: float + Marginal relaxation term >= 0 + reg_div: string, optional + Divergence used for regularization. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + reg_div: string, optional + Divergence to quantify the difference between the marginals. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + ot_distance : array-like + the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}` + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2) + 0.25 + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2) + 0.57 + + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd2 : Unregularized OT loss + ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss + """ + nx = get_backend(M, a, b) + + M0 = M + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) + + if G0 is not None: + G0 = nx.to_numpy(G0) + else: + G0 = np.zeros(M.shape) + + _func = _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div, regm_div) + + res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), + tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) + + G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0) + + if log: + log = {'loss': nx.from_numpy(res.fun, type_as=M0), 'res': res} + return G, log + else: + return G diff --git a/ot/utils.py b/ot/utils.py index e3437da..9093f09 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature -from .backend import get_backend, Backend +from .backend import get_backend, Backend, NumpyBackend __time_tic_toc = time.time() @@ -611,3 +611,203 @@ class UndefinedParameter(Exception): """ pass + + +class OTResult: + def __init__(self, potentials=None, value=None, value_linear=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None): + + self._potentials = potentials + self._value = value + self._value_linear = value_linear + self._plan = plan + self._log = log + self._sparse_plan = sparse_plan + self._lazy_plan = lazy_plan + self._backend = backend if backend is not None else NumpyBackend() + self._status = status + + # I assume that other solvers may return directly + # some primal objects? + # In the code below, let's define the main quantities + # that may be of interest to users. + # An OT solver returns an object that inherits from OTResult + # (e.g. SinkhornOTResult) and implements the relevant + # methods (e.g. "plan" and "lazy_plan" but not "sparse_plan", etc.). + # log is a dictionary containing potential information about the solver + + # Dual potentials -------------------------------------------- + + def __repr__(self): + s = 'OTResult(' + if self._value is not None: + s += 'value={},'.format(self._value) + if self._value_linear is not None: + s += 'value_linear={},'.format(self._value_linear) + if self._plan is not None: + s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape) + + if s[-1] != '(': + s = s[:-1] + ')' + else: + s = s + ')' + return s + + @property + def potentials(self): + """Dual potentials, i.e. Lagrange multipliers for the marginal constraints. + + This pair of arrays has the same shape, numerical type + and properties as the input weights "a" and "b". + """ + if self._potentials is not None: + return self._potentials + else: + raise NotImplementedError() + + @property + def potential_a(self): + """First dual potential, associated to the "source" measure "a".""" + if self._potentials is not None: + return self._potentials[0] + else: + raise NotImplementedError() + + @property + def potential_b(self): + """Second dual potential, associated to the "target" measure "b".""" + if self._potentials is not None: + return self._potentials[1] + else: + raise NotImplementedError() + + # Transport plan ------------------------------------------- + @property + def plan(self): + """Transport plan, encoded as a dense array.""" + # N.B.: We may catch out-of-memory errors and suggest + # the use of lazy_plan or sparse_plan when appropriate. + + if self._plan is not None: + return self._plan + else: + raise NotImplementedError() + + @property + def sparse_plan(self): + """Transport plan, encoded as a sparse array.""" + if self._sparse_plan is not None: + return self._sparse_plan + elif self._plan is not None: + return self._backend.tocsr(self._plan) + else: + raise NotImplementedError() + + @property + def lazy_plan(self): + """Transport plan, encoded as a symbolic KeOps LazyTensor.""" + raise NotImplementedError() + + # Loss values -------------------------------- + + @property + def value(self): + """Full transport cost, including possible regularization terms.""" + if self._value is not None: + return self._value + else: + raise NotImplementedError() + + @property + def value_linear(self): + """The "minimal" transport cost, i.e. the product between the transport plan and the cost.""" + if self._value_linear is not None: + return self._value_linear + else: + raise NotImplementedError() + + # Marginal constraints ------------------------- + @property + def marginals(self): + """Marginals of the transport plan: should be very close to "a" and "b" + for balanced OT.""" + if self._plan is not None: + return self.marginal_a, self.marginal_b + else: + raise NotImplementedError() + + @property + def marginal_a(self): + """First marginal of the transport plan, with the same shape as "a".""" + if self._plan is not None: + return self._backend.sum(self._plan, 1) + else: + raise NotImplementedError() + + @property + def marginal_b(self): + """Second marginal of the transport plan, with the same shape as "b".""" + if self._plan is not None: + return self._backend.sum(self._plan, 0) + else: + raise NotImplementedError() + + @property + def status(self): + """Optimization status of the solver.""" + if self._status is not None: + return self._status + else: + raise NotImplementedError() + + # Barycentric mappings ------------------------- + # Return the displacement vectors as an array + # that has the same shape as "xa"/"xb" (for samples) + # or "a"/"b" * D (for images)? + + @property + def a_to_b(self): + """Displacement vectors from the first to the second measure.""" + raise NotImplementedError() + + @property + def b_to_a(self): + """Displacement vectors from the second to the first measure.""" + raise NotImplementedError() + + # # Wasserstein barycenters ---------------------- + # @property + # def masses(self): + # """Masses for the Wasserstein barycenter.""" + # raise NotImplementedError() + + # @property + # def samples(self): + # """Sample locations for the Wasserstein barycenter.""" + # raise NotImplementedError() + + # Miscellaneous -------------------------------- + + @property + def citation(self): + """Appropriate citation(s) for this result, in plain text and BibTex formats.""" + + # The string below refers to the POT library: + # successor methods may concatenate the relevant references + # to the original definitions, solvers and underlying numerical backends. + return """POT library: + + POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; + + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {{POT}: {Python} {Optimal} {Transport}}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} + } + """ diff --git a/test/test_backend.py b/test/test_backend.py index 311c075..3628f61 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -274,6 +274,8 @@ def test_empty_backend(): nx.inv(M) with pytest.raises(NotImplementedError): nx.sqrtm(M) + with pytest.raises(NotImplementedError): + nx.kl_div(M, M) with pytest.raises(NotImplementedError): nx.isfinite(M) with pytest.raises(NotImplementedError): @@ -592,6 +594,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("matrix square root") + A = nx.kl_div(nx.abs(Mb), nx.abs(Mb) + 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append("Kullback-Leibler divergence") + A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0) A = nx.isfinite(A) lst_b.append(nx.to_numpy(A)) diff --git a/test/test_partial.py b/test/test_partial.py index 33fc259..ae4a1ab 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -79,6 +79,8 @@ def test_partial_wasserstein_lagrange(): w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True) + w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True) + def test_partial_wasserstein(): diff --git a/test/test_solvers.py b/test/test_solvers.py new file mode 100644 index 0000000..b792aca --- /dev/null +++ b/test/test_solvers.py @@ -0,0 +1,133 @@ +"""Tests for ot solvers""" + +# Author: Remi Flamary +# +# License: MIT License + + +import itertools +import numpy as np +import pytest + +import ot + + +lst_reg = [None, 1.0] +lst_reg_type = ['KL', 'entropy', 'L2'] +lst_unbalanced = [None, 0.9] +lst_unbalanced_type = ['KL', 'L2', 'TV'] + + +def assert_allclose_sol(sol1, sol2): + + lst_attr = ['value', 'value_linear', 'plan', + 'potential_a', 'potential_b', 'marginal_a', 'marginal_b'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +def test_solve(nx): + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = ot.dist(x, y) + + # solve unif weights + sol0 = ot.solve(M) + + print(sol0) + + # solve signe weights + sol = ot.solve(M, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + ab, bb, Mb = nx.from_numpy(a, b, M) + solb = ot.solve(M, a, b) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence') + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = ot.dist(x, y) + + try: + + # solve unif weights + sol0 = ot.solve(M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + + # solve signe weights + sol = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + + assert_allclose_sol(sol0, sol) + + # solve in backend + ab, bb, Mb = nx.from_numpy(a, b, M) + solb = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + + assert_allclose_sol(sol, solb) + except NotImplementedError: + pass + + +def test_solve_not_implemented(nx): + + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + + M = ot.dist(x, y) + + # test not implemented and check raise + with pytest.raises(NotImplementedError): + ot.solve(M, reg=1.0, reg_type='cryptic divergence') + with pytest.raises(NotImplementedError): + ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence') + + # pairs of incompatible divergences + with pytest.raises(NotImplementedError): + ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index fc40df0..b76d738 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -5,6 +5,7 @@ # # License: MIT License +import itertools import numpy as np import ot import pytest @@ -289,6 +290,28 @@ def test_implemented_methods(nx): method=method) +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_unbalanced(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + + M = ot.dist(xs, xt) + + a = ot.unif(5) + b = ot.unif(6) + + G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + np.testing.assert_allclose(G, nx.to_numpy(Gb)) + + def test_mm_convergence(nx): n = 100 rng = np.random.RandomState(42) diff --git a/test/test_utils.py b/test/test_utils.py index 19b6365..666c157 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -301,3 +301,32 @@ def test_BaseEstimator(): cl.set_params(bibi=10) assert cl.first == 'spam again' + + +def test_OTResult(): + + res = ot.utils.OTResult() + + # test print + print(res) + + # tets get citation + print(res.citation) + + lst_attributes = ['a_to_b', + 'b_to_a', + 'lazy_plan', + 'marginal_a', + 'marginal_b', + 'marginals', + 'plan', + 'potential_a', + 'potential_b', + 'potentials', + 'sparse_plan', + 'status', + 'value', + 'value_linear'] + for at in lst_attributes: + with pytest.raises(NotImplementedError): + getattr(res, at) -- cgit v1.2.3 From 80e3c23bc968f866fd20344ddc443a3c7fcb3b0d Mon Sep 17 00:00:00 2001 From: Clément Bonet <32179275+clbonet@users.noreply.github.com> Date: Thu, 23 Feb 2023 08:31:01 +0100 Subject: [WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434) * W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn --- CONTRIBUTORS.md | 1 + README.md | 10 +- RELEASES.md | 4 + examples/backends/plot_ssw_unif_torch.py | 153 ++++++ examples/plot_compute_wasserstein_circle.py | 161 ++++++ examples/sliced-wasserstein/plot_variance_ssw.py | 111 ++++ ot/__init__.py | 13 +- ot/backend.py | 204 +++++++- ot/lp/__init__.py | 7 +- ot/lp/solver_1d.py | 627 ++++++++++++++++++++++- ot/sliced.py | 185 ++++++- ot/utils.py | 30 ++ test/test_1d_solver.py | 127 +++++ test/test_backend.py | 46 ++ test/test_sliced.py | 186 +++++++ test/test_utils.py | 10 + 16 files changed, 1852 insertions(+), 23 deletions(-) create mode 100644 examples/backends/plot_ssw_unif_torch.py create mode 100644 examples/plot_compute_wasserstein_circle.py create mode 100644 examples/sliced-wasserstein/plot_variance_ssw.py (limited to 'ot/backend.py') diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 67d8337..1437821 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,6 +41,7 @@ The contributors to this library are: * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) +* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein) ## Acknowledgments diff --git a/README.md b/README.md index 7c9475b..d5e6854 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. +* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] +* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. @@ -292,4 +294,10 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + +[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + +[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 4ed3625..f8ef653 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,6 +4,10 @@ #### New features +- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434) +- Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434) +- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434) +- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.lp.solver_1d.semidiscrete_wasserstein2_unif_circle` (PR #434) - Added Bures Wasserstein distance in `ot.gaussian` (PR ##428) - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py new file mode 100644 index 0000000..d1de5a9 --- /dev/null +++ b/examples/backends/plot_ssw_unif_torch.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +r""" +================================================ +Spherical Sliced-Wasserstein Embedding on Sphere +================================================ + +Here, we aim at transforming samples into a uniform +distribution on the sphere by minimizing SSW: + +.. math:: + \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i}) + +where :math:`\nu=\mathrm{Unif}(S^1)`. + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import matplotlib.animation as animation +import torch +import torch.nn.functional as F + +import ot + + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) + +N = 1000 +x0 = torch.rand(N, 3) +x0 = F.normalize(x0, dim=-1) + + +# %% +# Plot data +# --------- + +def plot_sphere(ax): + xlist = np.linspace(-1.0, 1.0, 50) + ylist = np.linspace(-1.0, 1.0, 50) + r = np.linspace(1.0, 1.0, 50) + X, Y = np.meshgrid(xlist, ylist) + + Z = np.sqrt(r**2 - X**2 - Y**2) + + ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3) + ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half + + +# plot the distributions +pl.figure(1) +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5) +ax.set_title('Data distribution') +ax.legend() + + +# %% +# Gradient descent +# ---------------- + +x = x0.clone() +x.requires_grad_(True) + +n_iter = 500 +lr = 100 + +losses = [] +xvisu = torch.zeros(n_iter, N, 3) + +for i in range(n_iter): + sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500) + grad_x = torch.autograd.grad(sw, x)[0] + + x = x - lr * grad_x + x = F.normalize(x, p=2, dim=1) + + losses.append(sw.item()) + xvisu[i, :, :] = x.detach().clone() + + if i % 100 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + +pl.figure(1) +pl.semilogy(losses) +pl.grid() +pl.title('SSW') +pl.xlabel("Iterations") + + +# %% +# Plot trajectories of generated samples along iterations +# ------------------------------------------------------- + +ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499] + +fig = pl.figure(3, (10, 10)) +for i in range(9): + # pl.subplot(3, 3, i + 1) + # ax = pl.axes(projection='3d') + ax = fig.add_subplot(3, 3, i + 1, projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5) + ax.set_title('Iter. {}'.format(ivisu[i])) + #ax.axis("off") + if i == 0: + ax.legend() + + +# %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(4, (8, 8)) + + +def _update_plot(i): + i = 3 * i + pl.clf() + ax = pl.axes(projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples$', alpha=0.5) + ax.axis("off") + ax.set_xlim((-1.5, 1.5)) + ax.set_ylim((-1.5, 1.5)) + ax.set_title('Iter. {}'.format(i)) + return 1 + + +print(xvisu.shape) + +i = 0 +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples from $G\#\mu_n$', alpha=0.5) +ax.axis("off") +ax.set_xlim((-1.5, 1.5)) +ax.set_ylim((-1.5, 1.5)) +ax.set_title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000) +# %% diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py new file mode 100644 index 0000000..3ede96f --- /dev/null +++ b/examples/plot_compute_wasserstein_circle.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +""" +========================= +OT distance on the Circle +========================= + +Shows how to compute the Wasserstein distance on the circle + + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot + +from scipy.special import iv + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + + +def pdf_von_Mises(theta, mu, kappa): + pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa)) + return pdf + + +t = np.linspace(0, 2 * np.pi, 1000, endpoint=False) + +mu1 = 1 +kappa1 = 20 + +mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10) + + +pdf1 = pdf_von_Mises(t, mu1, kappa1) + + +pl.figure(1) +for k, mu in enumerate(mu_targets): + pdf_t = pdf_von_Mises(t, mu, kappa1) + if k == 0: + label = "Source distributions" + else: + label = None + pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label) + +pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution") +pl.legend() + +mu2 = 0 +kappa2 = kappa1 + +x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi +x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi + +angles = np.linspace(0, 2 * np.pi, 150) + +pl.figure(2) +pl.plot(np.cos(angles), np.sin(angles), c="k") +pl.xlim(-1.25, 1.25) +pl.ylim(-1.25, 1.25) +pl.scatter(np.cos(x1), np.sin(x1), c="b") +pl.scatter(np.cos(x2), np.sin(x2), c="r") + +######################################################################################### +# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle +# --------------------------------------------------------------------------------------- +# This examples illustrates the periodicity of the Wasserstein distance on the circle. +# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}` +# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution +# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`. +# The Wasserstein distance on the circle takes into account the periodicity +# and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the +# Euclidean version. + +#%% Compute and plot distributions + +mu_targets = np.linspace(0, 2 * np.pi, 200) +xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi + +n_try = 5 + +xts = np.zeros((n_try, 200, 500)) +for i in range(n_try): + for k, mu in enumerate(mu_targets): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi + xts[i, k] = xt + +# Put data on S^1=[0,1[ +xts2 = xts / (2 * np.pi) +xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi) + +L_w2_circle = np.zeros((n_try, 200)) +L_w2 = np.zeros((n_try, 200)) + +for i in range(n_try): + w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2) + w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2) + + L_w2_circle[i] = w2_circle + L_w2[i] = w2 + +m_w2_circle = np.mean(L_w2_circle, axis=0) +std_w2_circle = np.std(L_w2_circle, axis=0) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle") +pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5) +pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein") +pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5) +pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$") +pl.legend() +pl.xlabel(r"$\mu_{\mathrm{source}}$") +pl.show() + + +######################################################################## +# Wasserstein distance between von Mises and uniform for different kappa +# ---------------------------------------------------------------------- +# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`. + +#%% Compute Wasserstein between Von Mises and uniform + +kappas = np.logspace(-5, 2, 100) +n_try = 20 + +xts = np.zeros((n_try, 100, 500)) +for i in range(n_try): + for k, kappa in enumerate(kappas): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi + xts[i, k] = xt / (2 * np.pi) + +L_w2 = np.zeros((n_try, 100)) +for i in range(n_try): + L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(kappas, m_w2) +pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5) +pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$") +pl.xlabel(r"$\kappa$") +pl.show() + +# %% diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py new file mode 100644 index 0000000..83d458f --- /dev/null +++ b/examples/sliced-wasserstein/plot_variance_ssw.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Spherical Sliced Wasserstein on distributions in S^2 +==================================================== + +This example illustrates the computation of the spherical sliced Wasserstein discrepancy as +proposed in [46]. + +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). 'Spherical Sliced-Wasserstein". International Conference on Learning Representations. + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import matplotlib.pylab as pl +import numpy as np + +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 500 # nb samples + +xs = np.random.randn(n, 3) +xt = np.random.randn(n, 3) + +xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True)) +xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True)) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +############################################################################## +# Plot data +# --------- + +# %% plot samples + +fig = pl.figure(figsize=(10, 10)) +ax = pl.axes(projection='3d') +ax.grid(False) + +u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j] +x = np.cos(u) * np.sin(v) +y = np.sin(u) * np.sin(v) +z = np.cos(v) +ax.plot_surface(x, y, z, color="gray", alpha=0.03) +ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color="gray") + +ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label="Source") +ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label="Target") + +fs = 10 +# Labels +ax.set_xlabel('x', fontsize=fs) +ax.set_ylabel('y', fontsize=fs) +ax.set_zlabel('z', fontsize=fs) + +ax.view_init(20, 120) +ax.set_xlim(-1.5, 1.5) +ax.set_ylim(-1.5, 1.5) +ax.set_zlim(-1.5, 1.5) + +# Ticks +ax.set_xticks([-1, 0, 1]) +ax.set_yticks([-1, 0, 1]) +ax.set_zticks([-1, 0, 1]) + +pl.legend(loc=0) +pl.title("Source and Target distribution") + +############################################################################### +# Spherical Sliced Wasserstein for different seeds and number of projections +# -------------------------------------------------------------------------- + +n_seed = 50 +n_projections_arr = np.logspace(0, 3, 25, dtype=int) +res = np.empty((n_seed, 25)) + +# %% Compute statistics +for seed in range(n_seed): + for i, n_projections in enumerate(n_projections_arr): + res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +############################################################################### +# Plot Spherical Sliced Wasserstein +# --------------------------------- + +pl.figure(2) +pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$") +pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + +pl.xlabel("Number of projections") +pl.ylabel("Distance") +pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval') + +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 0b55e0c..45d5cfa 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -38,12 +38,15 @@ from . import solvers from . import gaussian # OT functions -from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d +from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance +from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance, + sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif) from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport @@ -60,8 +63,10 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'factored_optimal_transport', 'solve', - 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers'] + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', + 'binary_search_circle', 'wasserstein_circle', + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/backend.py b/ot/backend.py index 337e040..0779243 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -534,9 +534,9 @@ class Backend(): """ raise NotImplementedError() - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value=0): r""" - Pads a tensor. + Pads a tensor with a given value (0 by default). This function follows the api from :any:`numpy.pad` @@ -895,6 +895,62 @@ class Backend(): """ raise NotImplementedError() + def tile(self, a, reps): + r""" + Construct an array by repeating a the number of times given by reps + + See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html + """ + raise NotImplementedError() + + def floor(self, a): + r""" + Return the floor of the input element-wise + + See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html + """ + raise NotImplementedError() + + def prod(self, a, axis=None): + r""" + Return the product of all elements. + + See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html + """ + raise NotImplementedError() + + def sort2(self, a, axis=None): + r""" + Return the sorted array and the indices to sort the array + + See: https://pytorch.org/docs/stable/generated/torch.sort.html + """ + raise NotImplementedError() + + def qr(self, a): + r""" + Return the QR factorization + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html + """ + raise NotImplementedError() + + def atan2(self, a, b): + r""" + Element wise arctangent + + See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html + """ + raise NotImplementedError() + + def transpose(self, a, axes=None): + r""" + Returns a tensor that is a transposed version of a. The given dimensions dim0 and dim1 are swapped. + + See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1039,8 +1095,8 @@ class NumpyBackend(Backend): def concatenate(self, arrays, axis=0): return np.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return np.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return np.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return np.argmax(a, axis=axis) @@ -1185,6 +1241,44 @@ class NumpyBackend(Backend): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return np.tile(a, reps) + + def floor(self, a): + return np.floor(a) + + def prod(self, a, axis=0): + return np.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + np_version = tuple([int(k) for k in np.__version__.split(".")]) + if np_version < (1, 22, 0): + M, N = a.shape[-2], a.shape[-1] + K = min(M, N) + + if len(a.shape) >= 3: + n = a.shape[0] + + qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N)) + + for i in range(a.shape[0]): + qs[i], rs[i] = np.linalg.qr(a[i]) + + else: + return np.linalg.qr(a) + + return qs, rs + return np.linalg.qr(a) + + def atan2(self, a, b): + return np.arctan2(a, b) + + def transpose(self, a, axes=None): + return np.transpose(a, axes) + class JaxBackend(Backend): """ @@ -1351,8 +1445,8 @@ class JaxBackend(Backend): def concatenate(self, arrays, axis=0): return jnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return jnp.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return jnp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return jnp.argmax(a, axis=axis) @@ -1511,6 +1605,27 @@ class JaxBackend(Backend): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return jnp.tile(a, reps) + + def floor(self, a): + return jnp.floor(a) + + def prod(self, a, axis=0): + return jnp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return jnp.linalg.qr(a) + + def atan2(self, a, b): + return jnp.arctan2(a, b) + + def transpose(self, a, axes=None): + return jnp.transpose(a, axes) + class TorchBackend(Backend): """ @@ -1729,13 +1844,13 @@ class TorchBackend(Backend): def concatenate(self, arrays, axis=0): return torch.cat(arrays, dim=axis) - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value=0): from torch.nn.functional import pad # pad_width is an array of ndim tuples indicating how many 0 before and after # we need to add. We first need to make it compliant with torch syntax, that # starts with the last dim, then second last, etc. how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl) - return pad(a, how_pad) + return pad(a, how_pad, value=value) def argmax(self, a, axis=None): return torch.argmax(a, dim=axis) @@ -1934,6 +2049,29 @@ class TorchBackend(Backend): def is_floating_point(self, a): return a.dtype.is_floating_point + def tile(self, a, reps): + return a.repeat(reps) + + def floor(self, a): + return torch.floor(a) + + def prod(self, a, axis=0): + return torch.prod(a, dim=axis) + + def sort2(self, a, axis=-1): + return torch.sort(a, axis) + + def qr(self, a): + return torch.linalg.qr(a) + + def atan2(self, a, b): + return torch.atan2(a, b) + + def transpose(self, a, axes=None): + if axes is None: + axes = tuple(range(a.ndim)[::-1]) + return a.permute(axes) + class CupyBackend(Backend): # pragma: no cover """ @@ -2096,8 +2234,8 @@ class CupyBackend(Backend): # pragma: no cover def concatenate(self, arrays, axis=0): return cp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return cp.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return cp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return cp.argmax(a, axis=axis) @@ -2284,6 +2422,27 @@ class CupyBackend(Backend): # pragma: no cover def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return cp.tile(a, reps) + + def floor(self, a): + return cp.floor(a) + + def prod(self, a, axis=0): + return cp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return cp.linalg.qr(a) + + def atan2(self, a, b): + return cp.arctan2(a, b) + + def transpose(self, a, axes=None): + return cp.transpose(a, axes) + class TensorflowBackend(Backend): @@ -2454,8 +2613,8 @@ class TensorflowBackend(Backend): def concatenate(self, arrays, axis=0): return tnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return tnp.pad(a, pad_width, mode="constant") + def zero_pad(self, a, pad_width, value=0): + return tnp.pad(a, pad_width, mode="constant", constant_values=value) def argmax(self, a, axis=None): return tnp.argmax(a, axis=axis) @@ -2646,3 +2805,24 @@ class TensorflowBackend(Backend): def is_floating_point(self, a): return a.dtype.is_floating + + def tile(self, a, reps): + return tnp.tile(a, reps) + + def floor(self, a): + return tf.floor(a) + + def prod(self, a, axis=0): + return tnp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return tf.linalg.qr(a) + + def atan2(self, a, b): + return tf.math.atan2(a, b) + + def transpose(self, a, axes=None): + return tf.transpose(a, perm=axes) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 17411d0..7d0640f 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -20,14 +20,17 @@ from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import emd_1d, emd2_1d, wasserstein_1d +from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter'] + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', + 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle'] def check_number_threads(numThreads): diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 43763a9..e7add89 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -53,7 +53,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ distributions .. math: - OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq It is formally the p-Wasserstein distance raised to the power p. We do so in a vectorized way by first building the individual quantile functions then integrating them. @@ -365,3 +365,628 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, log_emd = {'G': G} return cost, log_emd return cost + + +def roll_cols(M, shifts): + r""" + Utils functions which allow to shift the order of each row of a 2d matrix + + Parameters + ---------- + M : (nr, nc) ndarray + Matrix to shift + shifts: int or (nr,) ndarray + + Returns + ------- + Shifted array + + Examples + -------- + >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]]) + >>> roll_cols(M, 2) + array([[2, 3, 1], + [5, 6, 4], + [8, 9, 7]]) + >>> roll_cols(M, np.array([[1],[2],[1]])) + array([[3, 1, 2], + [5, 6, 4], + [9, 7, 8]]) + + References + ---------- + https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch + """ + nx = get_backend(M) + + n_rows, n_cols = M.shape + + arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1)) + arange2 = (arange1 - shifts) % n_cols + + return nx.take_along_axis(M, arange2, 1) + + +def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): + r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + dCp: array-like, shape (n_batch, 1) + The batched right derivative + dCm: array-like, shape (n_batch, 1) + The batched left derivative + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + n = u_values.shape[-1] + m_batch, m = v_values.shape + + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + + # quantiles of F_u evaluated in F_v^\theta + u_index = nx.searchsorted(u_cdf, v_cdf_theta) + u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1) + + # Deal with 1 + u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1) + u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1) + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdfm = u_cdfm.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + + u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right") + u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1) + + dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1) + + dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1) + + return dCp.reshape(-1, 1), dCm.reshape(-1, 1) + + +def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): + r""" Computes the the cost (Equation (6.2) of [1]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + ot_cost: array-like, shape (n_batch,) + OT cost evaluated at theta + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + m_batch, m = v_values.shape + n_batch, n = u_values.shape + + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + # Put negative values at the end + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + + # Compute absciss + cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1) + cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)]) + + delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1] + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + cdf_axis = cdf_axis.contiguous() + + # Compute icdf + u_index = nx.searchsorted(u_cdf, cdf_axis) + u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1) + + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + v_index = nx.searchsorted(v_cdf_theta, cdf_axis) + v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1) + + if p == 1: + ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1) + else: + ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1) + + return ot_cost + + +def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True, + log=False): + r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + .. math:: + W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + where: + + - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v` + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + using e.g. ot.utils.get_coordinate_circle(x) + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional (default=1) + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC + Lp : int, optional + Upper bound dC + tm: float, optional + Lower bound theta + tp: float, optional + Upper bound theta + eps: float, optional + Stopping condition + require_sort: bool, optional + If True, sort the values. + log: bool, optional + If True, returns also the optimal theta + + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict, optional + log dictionary returned only if log==True in parameters + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> binary_search_circle(u.T, v.T, p=1) + array([0.1]) + + References + ---------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) + + u_values = u_values % 1 + v_values = v_values % 1 + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cdf = nx.cumsum(u_weights, 0).T + v_cdf = nx.cumsum(v_weights, 0).T + + u_values = u_values.T + v_values = v_values.T + + L = max(Lm, Lp) + + tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tm = nx.tile(tm, (1, m)) + tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tp = nx.tile(tp, (1, m)) + tc = (tm + tp) / 2 + + done = nx.zeros((u_values.shape[0], m)) + + cpt = 0 + while nx.any(1 - done): + cpt += 1 + + dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + done = ((dCp * dCm) <= 0) * 1 + + mask = ((tp - tm) < eps / L) * (1 - done) + + if nx.any(mask): + # can probably be improved by computing only relevant values + dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p) + dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p) + Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) + Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) + + mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) + tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0] + done[nx.prod(mask, axis=-1) > 0] = 1 + elif nx.any(1 - done): + tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] + tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] + tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2 + + w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + + if log: + return w, {"optimal_theta": tc[:, 0]} + return w + + +def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): + r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + using e.g. the atan2 function. + The function runs on backend but tensorflow is not supported. + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + require_sort: bool, optional + If True, sort the values. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> wasserstein1_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + """ + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) + + u_values = u_values % 1 + v_values = v_values % 1 + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0) + + cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0) + cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0) + + values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1) + delta = values_sorted[1:, ...] - values_sorted[:-1, ...] + weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0) + + sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5 + sum_weights[sum_weights < 0] = np.inf + inds = nx.argmin(sum_weights, axis=0) + + levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0) + + return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) + + +def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): + r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or + the binary search algorithm proposed in [44] otherwise. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates + using e.g. the atan2 function. + + General loss returned: + + .. math:: + OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + For p=1, [45] + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + using e.g. ot.utils.get_coordinate_circle(x) + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional (default=1) + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC. For p>1. + Lp : int, optional + Upper bound dC. For p>1. + tm: float, optional + Lower bound theta. For p>1. + tp: float, optional + Upper bound theta. For p>1. + eps: float, optional + Stopping condition. For p>1. + require_sort: bool, optional + If True, sort the values. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> wasserstein_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if p == 1: + return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort) + + return binary_search_circle(u_values, v_values, u_weights, v_weights, + p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps, + require_sort=require_sort) + + +def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): + r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1` + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + .. math:: + W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12} + + where: + + - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}` + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}, + + using e.g. ot.utils.get_coordinate_circle(x) + + Parameters + ---------- + u_values: ndarray, shape (n, ...) + Samples + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> x0 = np.array([[0], [0.2], [0.4]]) + >>> semidiscrete_wasserstein2_unif_circle(x0) + array([0.02111111]) + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + u_values = nx.sort(u_values, 0) + u_cdf = nx.cumsum(u_weights, 0) + u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) + + cpt1 = nx.sum(u_weights * u_values**2, axis=0) + u_mean = nx.sum(u_weights * u_values, axis=0) + + ns = 1 - u_weights - 2 * u_cdf[:-1] + cpt2 = nx.sum(u_values * u_weights * ns, axis=0) + + return cpt1 - u_mean**2 + cpt2 + 1 / 12 diff --git a/ot/sliced.py b/ot/sliced.py index 20891a4..077ff0b 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -12,7 +12,8 @@ Sliced OT Distances import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array +from .utils import list_to_array, get_coordinate_circle +from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): @@ -107,7 +108,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -208,7 +208,6 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -258,3 +257,183 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, if log: return res, {"projections": projections, "projected_emds": projected_emd} return res + + +def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, + p=2, seed=None, log=False): + r""" + Compute the spherical sliced-Wasserstein discrepancy. + + .. math:: + SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}} + + where: + + - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + X_t: ndarray, shape (n_samples_b, dim) + Samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional (default=2) + Power p used for computing the spherical sliced Wasserstein + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_sphere returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + >>> n_samples_a = 20 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True)) + >>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + if a is not None and b is not None: + nx = get_backend(X_s, X_t, a, b) + else: + nx = get_backend(X_s, X_t) + + n, d = X_s.shape + m, _ = X_t.shape + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("Xt is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) + + # Projection on S^1 + # Projection on plane + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1)) + + # Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) + + # Get coordinates on [0,1[ + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m)) + + projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p) + res = nx.mean(projected_emd) ** (1 / p) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res + + +def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False): + r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution. + + .. math:: + SSW_2(\mu_n, \nu) + + where + + - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}` + - :math:`\nu=\mathrm{Unif}(S^1)` + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + --------- + >>> np.random.seed(42) + >>> x0 = np.random.randn(500,3) + >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True)) + >>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42) + >>> np.allclose(sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3) + True + + References: + ----------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + if a is not None: + nx = get_backend(X_s, a) + else: + nx = get_backend(X_s) + + n, d = X_s.shape + + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) + + # Projection on S^1 + # Projection on plane + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + # Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + # Get coordinates on [0,1[ + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + + projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a) + res = nx.mean(projected_emd) ** (1 / 2) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res diff --git a/ot/utils.py b/ot/utils.py index 9093f09..3423a7e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -375,6 +375,36 @@ def check_random_state(seed): ' instance'.format(seed)) +def get_coordinate_circle(x): + r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the coordinates in + turn (in [0,1[). + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + Parameters + ---------- + x: ndarray, shape (n, 2) + Samples on the circle with ambient coordinates + + Returns + ------- + x_t: ndarray, shape (n,) + Coordinates on [0,1[ + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi) + >>> x1, y1 = np.cos(u), np.sin(u) + >>> x = np.concatenate([x1, y1]).T + >>> get_coordinate_circle(x) + array([0.2, 0.5, 0.8]) + """ + nx = get_backend(x) + x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + return x_t + + class deprecated(object): r"""Decorator to mark a function or class as deprecated. diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 20f307a..21abd1d 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -218,3 +218,130 @@ def test_emd1d_device_tf(): nx.assert_same_dtype_device(xb, emd) nx.assert_same_dtype_device(xb, emd2) assert nx.dtype_device(emd)[1].startswith("GPU") + + +def test_wasserstein_1d_circle(): + # test binary_search_circle and wasserstein_circle give similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + + wass1 = ot.emd2(w_u, w_v, M1) + + wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) + w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1) + + M2 = M1**2 + wass2 = ot.emd2(w_u, w_v, M2) + wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2) + w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2) + + # check loss is similar + np.testing.assert_allclose(wass1, wass1_bsc) + np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2) + np.testing.assert_allclose(wass2, wass2_bsc) + np.testing.assert_allclose(wass2, w2_circle) + + +@pytest.skip_backend("tf") +def test_wasserstein1d_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) + + w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1) + w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2) + + nx.assert_same_dtype_device(xb, w1) + nx.assert_same_dtype_device(xb, w2_bsc) + + +def test_wasserstein_1d_unif_circle(): + # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle + n = 20 + m = 50000 + + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + # w_u = rng.uniform(0., 1., n) + # w_u = w_u / w_u.sum() + + w_u = ot.utils.unif(n) + w_v = ot.utils.unif(m) + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + wass2 = ot.emd2(w_u, w_v, M1**2) + + wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15) + wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) + + # check loss is similar + np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3) + np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3) + + +def test_wasserstein1d_unif_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) + + w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub) + + nx.assert_same_dtype_device(xb, w2) + + +def test_binary_search_circle_log(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True) + optimal_thetas = log["optimal_theta"] + + assert optimal_thetas.shape[0] == 1 + + +def test_wasserstein_circle_bad_shape(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n, 2) + v = rng.rand(m, 1) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=2) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=1) diff --git a/test/test_backend.py b/test/test_backend.py index 3628f61..fd9a761 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -282,6 +282,20 @@ def test_empty_backend(): nx.array_equal(M, M) with pytest.raises(NotImplementedError): nx.is_floating_point(M) + with pytest.raises(NotImplementedError): + nx.tile(M, (10, 1)) + with pytest.raises(NotImplementedError): + nx.floor(M) + with pytest.raises(NotImplementedError): + nx.prod(M) + with pytest.raises(NotImplementedError): + nx.sort2(M) + with pytest.raises(NotImplementedError): + nx.qr(M) + with pytest.raises(NotImplementedError): + nx.atan2(v, v) + with pytest.raises(NotImplementedError): + nx.transpose(M) def test_func_backends(nx): @@ -603,6 +617,38 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("isfinite") + A = nx.tile(vb, (10, 1)) + lst_b.append(nx.to_numpy(A)) + lst_name.append("tile") + + A = nx.floor(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("floor") + + A = nx.prod(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("prod") + + A, B = nx.sort2(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("sort2 sort") + lst_b.append(nx.to_numpy(B)) + lst_name.append("sort2 argsort") + + A, B = nx.qr(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("QR Q") + lst_b.append(nx.to_numpy(B)) + lst_name.append("QR R") + + A = nx.atan2(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("atan2") + + A = nx.transpose(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("transpose") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( diff --git a/test/test_sliced.py b/test/test_sliced.py index eb13469..f54c799 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -266,3 +266,189 @@ def test_max_sliced_backend_device_tf(): valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") + + +def test_projections_stiefel(): + rng = np.random.RandomState(0) + + n_projs = 500 + x = np.random.randn(100, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs, + seed=rng, log=True) + + P = log["projections"] + P_T = np.transpose(P, [0, 2, 1]) + np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)])) + + +def test_sliced_sphere_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_sliced_sphere_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere_values_on_the_sphere(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + assert res > 0. + + +def test_1d_sliced_sphere_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + a = rng.uniform(0, 1, n) + a /= a.sum() + + y = rng.randn(m, 2) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi) + u = ot.utils.unif(m) + + res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2) + expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2) + + res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1) + expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1) + + np.testing.assert_almost_equal(res ** 2, expected) + np.testing.assert_almost_equal(res1, expected1, decimal=3) + + +@pytest.skip_backend("tf") +def test_sliced_sphere_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(2 * n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, yb = nx.from_numpy(x, y, type_as=tp) + + valb = ot.sliced_wasserstein_sphere(xb, yb) + + nx.assert_same_dtype_device(xb, valb) + + +def test_sliced_sphere_unif_values_on_the_sphere(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng) + + +def test_sliced_sphere_unif_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_unif_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + + valb = ot.sliced_wasserstein_sphere_unif(xb) + + nx.assert_same_dtype_device(xb, valb) diff --git a/test/test_utils.py b/test/test_utils.py index 666c157..31b12ef 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -330,3 +330,13 @@ def test_OTResult(): for at in lst_attributes: with pytest.raises(NotImplementedError): getattr(res, at) + + +def test_get_coordinate_circle(): + + u = np.random.rand(1, 100) + x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi)) + x = np.concatenate([x1, y1]).T + x_p = ot.utils.get_coordinate_circle(x) + + np.testing.assert_allclose(u[0], x_p) -- cgit v1.2.3