summaryrefslogtreecommitdiff
path: root/ot/lp/solver_1d.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/solver_1d.py')
-rw-r--r--ot/lp/solver_1d.py629
1 files changed, 627 insertions, 2 deletions
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 43763a9..bcfc920 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -53,7 +53,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
distributions
.. math:
- OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq
+ OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq
It is formally the p-Wasserstein distance raised to the power p.
We do so in a vectorized way by first building the individual quantile functions then integrating them.
@@ -129,7 +129,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
diff_quantiles = nx.abs(u_quantiles - v_quantiles)
if p == 1:
- return nx.sum(delta * nx.abs(diff_quantiles), axis=0)
+ return nx.sum(delta * diff_quantiles, axis=0)
return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
@@ -365,3 +365,628 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
log_emd = {'G': G}
return cost, log_emd
return cost
+
+
+def roll_cols(M, shifts):
+ r"""
+ Utils functions which allow to shift the order of each row of a 2d matrix
+
+ Parameters
+ ----------
+ M : (nr, nc) ndarray
+ Matrix to shift
+ shifts: int or (nr,) ndarray
+
+ Returns
+ -------
+ Shifted array
+
+ Examples
+ --------
+ >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
+ >>> roll_cols(M, 2)
+ array([[2, 3, 1],
+ [5, 6, 4],
+ [8, 9, 7]])
+ >>> roll_cols(M, np.array([[1],[2],[1]]))
+ array([[3, 1, 2],
+ [5, 6, 4],
+ [9, 7, 8]])
+
+ References
+ ----------
+ https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
+ """
+ nx = get_backend(M)
+
+ n_rows, n_cols = M.shape
+
+ arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1))
+ arange2 = (arange1 - shifts) % n_cols
+
+ return nx.take_along_axis(M, arange2, 1)
+
+
+def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
+ r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ dCp: array-like, shape (n_batch, 1)
+ The batched right derivative
+ dCm: array-like, shape (n_batch, 1)
+ The batched left derivative
+
+ References
+ ---------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ n = u_values.shape[-1]
+ m_batch, m = v_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ # quantiles of F_u evaluated in F_v^\theta
+ u_index = nx.searchsorted(u_cdf, v_cdf_theta)
+ u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)
+
+ # Deal with 1
+ u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1)
+ u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ u_cdfm = u_cdfm.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right")
+ u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1)
+
+ dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1)
+
+ dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1)
+
+ return dCp.reshape(-1, 1), dCm.reshape(-1, 1)
+
+
+def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p):
+ r""" Computes the the cost (Equation (6.2) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ ot_cost: array-like, shape (n_batch,)
+ OT cost evaluated at theta
+
+ References
+ ---------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ m_batch, m = v_values.shape
+ n_batch, n = u_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ # Put negative values at the end
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ # Compute absciss
+ cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1)
+ cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)])
+
+ delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1]
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+ cdf_axis = cdf_axis.contiguous()
+
+ # Compute icdf
+ u_index = nx.searchsorted(u_cdf, cdf_axis)
+ u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1)
+
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+ v_index = nx.searchsorted(v_cdf_theta, cdf_axis)
+ v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1)
+
+ if p == 1:
+ ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1)
+ else:
+ ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1)
+
+ return ot_cost
+
+
+def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True,
+ log=False):
+ r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ where:
+
+ - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC
+ Lp : int, optional
+ Upper bound dC
+ tm: float, optional
+ Lower bound theta
+ tp: float, optional
+ Upper bound theta
+ eps: float, optional
+ Stopping condition
+ require_sort: bool, optional
+ If True, sort the values.
+ log: bool, optional
+ If True, returns also the optimal theta
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict, optional
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> binary_search_circle(u.T, v.T, p=1)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ u_cdf = nx.cumsum(u_weights, 0).T
+ v_cdf = nx.cumsum(v_weights, 0).T
+
+ u_values = u_values.T
+ v_values = v_values.T
+
+ L = max(Lm, Lp)
+
+ tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tm = nx.tile(tm, (1, m))
+ tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tp = nx.tile(tp, (1, m))
+ tc = (tm + tp) / 2
+
+ done = nx.zeros((u_values.shape[0], m))
+
+ cpt = 0
+ while nx.any(1 - done):
+ cpt += 1
+
+ dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+ done = ((dCp * dCm) <= 0) * 1
+
+ mask = ((tp - tm) < eps / L) * (1 - done)
+
+ if nx.any(mask):
+ # can probably be improved by computing only relevant values
+ dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p)
+ dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p)
+ Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+ Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+
+ mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
+ tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0]
+ done[nx.prod(mask, axis=-1) > 0] = 1
+ elif nx.any(1 - done):
+ tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0]
+ tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0]
+ tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2
+
+ w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+
+ if log:
+ return w, {"optimal_theta": tc[:, 0]}
+ return w
+
+
+def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True):
+ r"""Computes the 1-Wasserstein distance on the circle using the level median [45].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates
+ using e.g. the atan2 function.
+ The function runs on backend but tensorflow is not supported.
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein1_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ """
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0)
+
+ cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0)
+ cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0)
+
+ values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1)
+ delta = values_sorted[1:, ...] - values_sorted[:-1, ...]
+ weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0)
+
+ sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5
+ sum_weights[sum_weights < 0] = np.inf
+ inds = nx.argmin(sum_weights, axis=0)
+
+ levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0)
+
+ return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0)
+
+
+def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True):
+ r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or
+ the binary search algorithm proposed in [44] otherwise.
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates
+ using e.g. the atan2 function.
+
+ General loss returned:
+
+ .. math::
+ OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ For p=1, [45]
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC. For p>1.
+ Lp : int, optional
+ Upper bound dC. For p>1.
+ tm: float, optional
+ Lower bound theta. For p>1.
+ tp: float, optional
+ Upper bound theta. For p>1.
+ eps: float, optional
+ Stopping condition. For p>1.
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if p == 1:
+ return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort)
+
+ return binary_search_circle(u_values, v_values, u_weights, v_weights,
+ p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps,
+ require_sort=require_sort)
+
+
+def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
+ r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1`
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12}
+
+ where:
+
+ - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ Parameters
+ ----------
+ u_values: ndarray, shape (n, ...)
+ Samples
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> x0 = np.array([[0], [0.2], [0.4]])
+ >>> semidiscrete_wasserstein2_unif_circle(x0)
+ array([0.02111111])
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+
+ if u_weights is not None:
+ nx = get_backend(u_values, u_weights)
+ else:
+ nx = get_backend(u_values)
+
+ n = u_values.shape[0]
+
+ u_values = u_values % 1
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+
+ u_values = nx.sort(u_values, 0)
+ u_cdf = nx.cumsum(u_weights, 0)
+ u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)])
+
+ cpt1 = nx.sum(u_weights * u_values**2, axis=0)
+ u_mean = nx.sum(u_weights * u_values, axis=0)
+
+ ns = 1 - u_weights - 2 * u_cdf[:-1]
+ cpt2 = nx.sum(u_values * u_weights * ns, axis=0)
+
+ return cpt1 - u_mean**2 + cpt2 + 1 / 12