summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2022-12-15 09:28:01 +0100
committerGitHub <noreply@github.com>2022-12-15 09:28:01 +0100
commit0411ea22a96f9c22af30156b45c16ef39ffb520d (patch)
tree7c131ad804d5b16a8c362c2fe296350a770400df
parent8490196dcc982c492b7565e1ec4de5f75f006acf (diff)
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
-rw-r--r--RELEASES.md3
-rw-r--r--ot/__init__.py7
-rw-r--r--ot/backend.py30
-rwxr-xr-xot/partial.py47
-rw-r--r--ot/smooth.py11
-rw-r--r--ot/solvers.py347
-rw-r--r--ot/unbalanced.py189
-rw-r--r--ot/utils.py202
-rw-r--r--test/test_backend.py6
-rwxr-xr-xtest/test_partial.py2
-rw-r--r--test/test_solvers.py133
-rw-r--r--test/test_unbalanced.py23
-rw-r--r--test/test_utils.py29
13 files changed, 1011 insertions, 18 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 3bd84c1..9cfdd35 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -6,6 +6,9 @@
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)
+- New API for OT solver using function `ot.solve` (PR #388)
+- Backend version of `ot.partial` and `ot.smooth` (PR #388)
+
#### Closed issues
diff --git a/ot/__init__.py b/ot/__init__.py
index 15d8351..51eb726 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -34,6 +34,7 @@ from . import backend
from . import regpath
from . import weak
from . import factored
+from . import solvers
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -46,7 +47,7 @@ from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
-
+from .solvers import solve
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -61,5 +62,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
- 'factored_optimal_transport',
- 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
+ 'factored_optimal_transport', 'solve',
+ 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers']
diff --git a/ot/backend.py b/ot/backend.py
index e4b48e1..337e040 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -854,6 +854,21 @@ class Backend():
"""
raise NotImplementedError()
+ def kl_div(self, p, q, eps=1e-16):
+ r"""
+ Computes the Kullback-Leibler divergence.
+
+ This function follows the api from :any:`scipy.stats.entropy`.
+
+ Parameter eps is used to avoid numerical errors and is added in the log.
+
+ .. math::
+ KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
+ """
+ raise NotImplementedError()
+
def isfinite(self, a):
r"""
Tests element-wise for finiteness (not infinity and not Not a Number).
@@ -1158,6 +1173,9 @@ class NumpyBackend(Backend):
def sqrtm(self, a):
return scipy.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return np.sum(p * np.log(p / q + eps))
+
def isfinite(self, a):
return np.isfinite(a)
@@ -1481,6 +1499,9 @@ class JaxBackend(Backend):
L, V = jnp.linalg.eigh(a)
return (V * jnp.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return jnp.sum(p * jnp.log(p / q + eps))
+
def isfinite(self, a):
return jnp.isfinite(a)
@@ -1901,6 +1922,9 @@ class TorchBackend(Backend):
L, V = torch.linalg.eigh(a)
return (V * torch.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return torch.sum(p * torch.log(p / q + eps))
+
def isfinite(self, a):
return torch.isfinite(a)
@@ -2248,6 +2272,9 @@ class CupyBackend(Backend): # pragma: no cover
L, V = cp.linalg.eigh(a)
return (V * self.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return cp.sum(p * cp.log(p / q + eps))
+
def isfinite(self, a):
return cp.isfinite(a)
@@ -2608,6 +2635,9 @@ class TensorflowBackend(Backend):
def sqrtm(self, a):
return tf.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return tnp.sum(p * tnp.log(p / q + eps))
+
def isfinite(self, a):
return tnp.isfinite(a)
diff --git a/ot/partial.py b/ot/partial.py
index 0a9e450..eae91c4 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -8,6 +8,8 @@ Partial OT solvers
import numpy as np
from .lp import emd
+from .backend import get_backend
+from .utils import list_to_array
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
@@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass
"""
- if np.sum(a) > 1 or np.sum(b) > 1:
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
+ if nx.sum(a) > 1 or nx.sum(b) > 1:
raise ValueError("Problem infeasible. Check that a and b are in the "
"simplex")
if reg_m is None:
- reg_m = np.max(M) + 1
- if reg_m < -np.max(M):
- return np.zeros((len(a), len(b)))
+ reg_m = float(nx.max(M)) + 1
+ if reg_m < -nx.max(M):
+ return nx.zeros((len(a), len(b)), type_as=M)
+
+ a0, b0, M0 = a, b, M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
eps = 1e-20
M = np.asarray(M, dtype=np.float64)
@@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
gamma = np.zeros((len(a), len(b)))
gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies]
+ # convert back to backend
+ gamma = nx.from_numpy(gamma, type_as=M0)
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['cost'] = np.sum(gamma * M)
+ log_emd['cost'] = nx.sum(gamma * M0)
+ log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0)
+ log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0)
+
if log:
return gamma, log_emd
else:
@@ -250,15 +266,23 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
entropic regularization parameter
"""
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
if m is None:
return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
elif m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- elif m > np.min((np.sum(a), np.sum(b))):
+ elif m > nx.min((nx.sum(a), nx.sum(b))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
+ a0, b0, M0 = a, b, M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
+
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
@@ -267,15 +291,20 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
**kwargs)
+
+ gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M)
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)])
+ log_emd['partial_w_dist'] = nx.sum(M0 * gamma)
+ log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0)
+ log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0)
if log:
- return gamma[:len(a), :len(b)], log_emd
+ return gamma, log_emd
else:
- return gamma[:len(a), :len(b)]
+ return gamma
def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
diff --git a/ot/smooth.py b/ot/smooth.py
index 6855005..8e0ef38 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -44,6 +44,7 @@ Original code from https://github.com/mblondel/smooth-ot/
import numpy as np
from scipy.optimize import minimize
+from .backend import get_backend
def projection_simplex(V, z=1, axis=None):
@@ -511,6 +512,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
"""
+ nx = get_backend(a, b, M)
+
if reg_type.lower() in ['l2', 'squaredl2']:
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
@@ -518,15 +521,19 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
else:
raise NotImplementedError('Unknown regularization')
+ a0, b0, M0 = a, b, M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
+
# solve dual
alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
tol=stopThr, verbose=verbose)
# reconstruct transport matrix
- G = get_plan_from_dual(alpha, beta, M, regul)
+ G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), type_as=M0)
if log:
- log = {'alpha': alpha, 'beta': beta, 'res': res}
+ log = {'alpha': nx.from_numpy(alpha, type_as=a0), 'beta': nx.from_numpy(beta, type_as=b0), 'res': res}
return G, log
else:
return G
diff --git a/ot/solvers.py b/ot/solvers.py
new file mode 100644
index 0000000..0294d71
--- /dev/null
+++ b/ot/solvers.py
@@ -0,0 +1,347 @@
+# -*- coding: utf-8 -*-
+"""
+General OT solvers with unified API
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+from .utils import OTResult
+from .lp import emd2
+from .backend import get_backend
+from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
+from .bregman import sinkhorn_log
+from .partial import partial_wasserstein_lagrange
+from .smooth import smooth_ot_dual
+
+
+def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
+ unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None,
+ potentials_init=None, tol=None, verbose=False):
+ r"""Solve the discrete optimal transport problem and return :any:`OTResult` object
+
+ The function solves the following general optimal transport problem
+
+ .. math::
+ \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
+ \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
+ \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+ The regularization is selected with :any:`reg` (:math:`\lambda_r`) and :any:`reg_type`. By
+ default ``reg=None`` and there is no regularization. The unbalanced marginal
+ penalization can be selected with :any:`unbalanced` (:math:`\lambda_u`) and
+ :any:`unbalanced_type`. By default ``unbalanced=None`` and the function
+ solves the exact optimal transport problem (respecting the marginals).
+
+ Parameters
+ ----------
+ M : array_like, shape (dim_a, dim_b)
+ Loss matrix
+ a : array-like, shape (dim_a,), optional
+ Samples weights in the source domain (default is uniform)
+ b : array-like, shape (dim_b,), optional
+ Samples weights in the source domain (default is uniform)
+ reg : float, optional
+ Regularization weight :math:`\lambda_r`, by default None (no reg., exact
+ OT)
+ reg_type : str, optional
+ Type of regularization :math:`R` either "KL", "L2", 'entropy', by default "KL"
+ unbalanced : float, optional
+ Unbalanced penalization weight :math:`\lambda_u`, by default None
+ (balanced OT)
+ unbalanced_type : str, optional
+ Type of unbalanced penalization unction :math:`U` either "KL", "L2", 'TV', by default 'KL'
+ n_threads : int, optional
+ Number of OMP threads for exact OT solver, by default 1
+ max_iter : int, optional
+ Maximum number of iteration, by default None (default values in each solvers)
+ plan_init : array_like, shape (dim_a, dim_b), optional
+ Initialization of the OT plan for iterative methods, by default None
+ potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
+ Initialization of the OT dual potentials for iterative methods, by default None
+ tol : _type_, optional
+ Tolerance for solution precision, by default None (default values in each solvers)
+ verbose : bool, optional
+ Print information in the solver, by default False
+
+ Returns
+ -------
+ res : OTResult()
+ Result of the optimization problem. The information can be obtained as follows:
+
+ - res.plan : OT plan :math:`\mathbf{T}`
+ - res.potentials : OT dual potentials
+ - res.value : Optimal value of the optimization problem
+ - res.value_linear : Linear OT loss with the optimal OT plan
+
+ See :any:`OTResult` for more information.
+
+ Notes
+ -----
+
+ The following methods are available for solving the OT problems:
+
+ - **Classical exact OT problem** (default parameters):
+
+ .. math::
+ \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F
+
+ s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}
+
+ \mathbf{T}^T \mathbf{1} = \mathbf{b}
+
+ \mathbf{T} \geq 0
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ res = ot.solve(M, a, b)
+
+ - **Entropic regularized OT** (when ``reg!=None``):
+
+ .. math::
+ \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})
+
+ s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}
+
+ \mathbf{T}^T \mathbf{1} = \mathbf{b}
+
+ \mathbf{T} \geq 0
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ # default is ``"KL"`` regularization (``reg_type="KL"``)
+ res = ot.solve(M, a, b, reg=1.0)
+ # or for original Sinkhorn paper formulation [2]
+ res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
+
+ - **Quadratic regularized OT** (when ``reg!=None`` and ``reg_type="L2"``):
+
+ .. math::
+ \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})
+
+ s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}
+
+ \mathbf{T}^T \mathbf{1} = \mathbf{b}
+
+ \mathbf{T} \geq 0
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ res = ot.solve(M,a,b,reg=1.0,reg_type='L2')
+
+ - **Unbalanced OT** (when ``unbalanced!=None``):
+
+ .. math::
+ \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ # default is ``"KL"``
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0)
+ # quadratic unbalanced OT
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2')
+ # TV = partial OT
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='TV')
+
+
+ - **Regularized unbalanced regularized OT** (when ``unbalanced!=None`` and ``reg!=None``):
+
+ .. math::
+ \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ # default is ``"KL"`` for both
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0)
+ # quadratic unbalanced OT with KL regularization
+ res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2')
+ # both quadratic
+ res = ot.solve(M,a,b,reg=1.0, reg_type='L2',unbalanced=1.0,unbalanced_type='L2')
+
+
+ .. _references-solve:
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
+ Optimal Transport. Proceedings of the Twenty-First International
+ Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
+ A., & Peyré, G. (2019, April). Interpolating between optimal transport
+ and MMD using Sinkhorn divergences. In The 22nd International Conference
+ on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+
+ """
+
+ # detect backend
+ arr = [M]
+ if a is not None:
+ arr.append(a)
+ if b is not None:
+ arr.append(b)
+ nx = get_backend(*arr)
+
+ # create uniform weights if not given
+ if a is None:
+ a = nx.ones(M.shape[0], type_as=M) / M.shape[0]
+ if b is None:
+ b = nx.ones(M.shape[1], type_as=M) / M.shape[1]
+
+ # default values for solutions
+ potentials = None
+ value = None
+ value_linear = None
+ plan = None
+ status = None
+
+ if reg is None or reg == 0: # exact OT
+
+ if unbalanced is None: # Exact balanced OT
+
+ # default values for EMD solver
+ if max_iter is None:
+ max_iter = 1000000
+
+ value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads)
+
+ value = value_linear
+ potentials = (log['u'], log['v'])
+ plan = log['G']
+ status = log["warning"] if log["warning"] is not None else 'Converged'
+
+ elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT
+
+ # default values for exact unbalanced OT
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-12
+
+ plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced,
+ div=unbalanced_type.lower(), numItermax=max_iter,
+ stopThr=tol, log=True,
+ verbose=verbose, G0=plan_init)
+
+ value_linear = log['cost']
+
+ if unbalanced_type.lower() == 'kl':
+ value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))
+ else:
+ err_a = nx.sum(plan, 1) - a
+ err_b = nx.sum(plan, 0) - b
+ value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2)
+
+ elif unbalanced_type.lower() == 'tv':
+
+ if max_iter is None:
+ max_iter = 1000000
+
+ plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2, log=True, numItermax=max_iter)
+
+ value_linear = nx.sum(M * plan)
+ err_a = nx.sum(plan, 1) - a
+ err_b = nx.sum(plan, 0) - b
+ value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) +
+ nx.sum(nx.abs(err_b))))
+
+ else:
+ raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type)))
+
+ else: # regularized OT
+
+ if unbalanced is None: # Balanced regularized OT
+
+ if reg_type.lower() in ['entropy', 'kl']:
+
+ # default values for sinkhorn
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-9
+
+ plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
+ stopThr=tol, log=True,
+ verbose=verbose)
+
+ value_linear = nx.sum(M * plan)
+
+ if reg_type.lower() == 'entropy':
+ value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
+ else:
+ value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :])
+
+ potentials = (log['log_u'], log['log_v'])
+
+ elif reg_type.lower() == 'l2':
+
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-9
+
+ plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter, stopThr=tol, log=True, verbose=verbose)
+
+ value_linear = nx.sum(M * plan)
+ value = value_linear + reg * nx.sum(plan**2)
+ potentials = (log['alpha'], log['beta'])
+
+ else:
+ raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type)))
+
+ else: # unbalanced AND regularized OT
+
+ if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':
+
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-9
+
+ plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)
+
+ value_linear = nx.sum(M * plan)
+
+ value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))
+
+ potentials = (log['logu'], log['logv'])
+
+ elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']:
+
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-12
+
+ plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)
+
+ value_linear = nx.sum(M * plan)
+
+ value = log['loss']
+
+ else:
+ raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))
+
+ res = OTResult(potentials=potentials, value=value,
+ value_linear=value_linear, plan=plan, status=status, backend=nx)
+
+ return res
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index dd9a36e..a71a0dd 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -10,6 +10,9 @@ Regularized Unbalanced OT solvers
from __future__ import division
import warnings
+import numpy as np
+from scipy.optimize import minimize, Bounds
+
from .backend import get_backend
from .utils import list_to_array
# from .utils import unif, dist
@@ -269,7 +272,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
- Solve the entropic regularization unbalanced optimal transport problem and return the loss
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the OT plan
The function solves the following optimization problem:
@@ -734,7 +738,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
else:
- assert len(weights) == A.shape[1]
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -882,7 +886,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
else:
- assert len(weights) == A.shape[1]
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -1252,3 +1256,182 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
return log_mm['cost'], log_mm
else:
return log_mm['cost']
+
+
+def _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl'):
+ """
+ return the loss function (scipy.optimize compatible) for regularized
+ unbalanced OT
+ """
+
+ m, n = M.shape
+
+ def kl(p, q):
+ return np.sum(p * np.log(p / q + 1e-16))
+
+ def reg_l2(G):
+ return np.sum((G - a[:, None] * b[None, :])**2) / 2
+
+ def grad_l2(G):
+ return G - a[:, None] * b[None, :]
+
+ def reg_kl(G):
+ return kl(G, a[:, None] * b[None, :])
+
+ def grad_kl(G):
+ return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + 1
+
+ def reg_entropy(G):
+ return kl(G, 1)
+
+ def grad_entropy(G):
+ return np.log(G + 1e-16) + 1
+
+ if reg_div == 'kl':
+ reg_fun = reg_kl
+ grad_reg_fun = grad_kl
+ elif reg_div == 'entropy':
+ reg_fun = reg_entropy
+ grad_reg_fun = grad_entropy
+ else:
+ reg_fun = reg_l2
+ grad_reg_fun = grad_l2
+
+ def marg_l2(G):
+ return 0.5 * np.sum((G.sum(1) - a)**2) + 0.5 * np.sum((G.sum(0) - b)**2)
+
+ def grad_marg_l2(G):
+ return np.outer((G.sum(1) - a), np.ones(n)) + np.outer(np.ones(m), (G.sum(0) - b))
+
+ def marg_kl(G):
+ return kl(G.sum(1), a) + kl(G.sum(0), b)
+
+ def grad_marg_kl(G):
+ return np.outer(np.log(G.sum(1) / a + 1e-16) + 1, np.ones(n)) + np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16) + 1)
+
+ if regm_div == 'kl':
+ regm_fun = marg_kl
+ grad_regm_fun = grad_marg_kl
+ else:
+ regm_fun = marg_l2
+ grad_regm_fun = grad_marg_l2
+
+ def _func(G):
+ G = G.reshape((m, n))
+
+ # compute loss
+ val = np.sum(G * M) + reg * reg_fun(G) + reg_m * regm_fun(G)
+
+ # compute gradient
+ grad = M + reg * grad_reg_fun(G) + reg_m * grad_regm_fun(G)
+
+ return val, grad.ravel()
+
+ return _func
+
+
+def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B.
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ + \mathrm{reg} \mathrm{div}(\gamma,\mathbf{a}\mathbf{b}^T)
+ \mathrm{reg_m} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+ The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg: float
+ regularization term (>=0)
+ reg_m: float
+ Marginal relaxation term >= 0
+ reg_div: string, optional
+ Divergence used for regularization.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ reg_div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ ot_distance : array-like
+ the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}`
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2)
+ 0.25
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2)
+ 0.57
+
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
+ See Also
+ --------
+ ot.lp.emd2 : Unregularized OT loss
+ ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+ """
+ nx = get_backend(M, a, b)
+
+ M0 = M
+ # convert to humpy
+ a, b, M = nx.to_numpy(a, b, M)
+
+ if G0 is not None:
+ G0 = nx.to_numpy(G0)
+ else:
+ G0 = np.zeros(M.shape)
+
+ _func = _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div, regm_div)
+
+ res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf),
+ tol=stopThr, options=dict(maxiter=numItermax, disp=verbose))
+
+ G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0)
+
+ if log:
+ log = {'loss': nx.from_numpy(res.fun, type_as=M0), 'res': res}
+ return G, log
+ else:
+ return G
diff --git a/ot/utils.py b/ot/utils.py
index e3437da..9093f09 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend, Backend
+from .backend import get_backend, Backend, NumpyBackend
__time_tic_toc = time.time()
@@ -611,3 +611,203 @@ class UndefinedParameter(Exception):
"""
pass
+
+
+class OTResult:
+ def __init__(self, potentials=None, value=None, value_linear=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):
+
+ self._potentials = potentials
+ self._value = value
+ self._value_linear = value_linear
+ self._plan = plan
+ self._log = log
+ self._sparse_plan = sparse_plan
+ self._lazy_plan = lazy_plan
+ self._backend = backend if backend is not None else NumpyBackend()
+ self._status = status
+
+ # I assume that other solvers may return directly
+ # some primal objects?
+ # In the code below, let's define the main quantities
+ # that may be of interest to users.
+ # An OT solver returns an object that inherits from OTResult
+ # (e.g. SinkhornOTResult) and implements the relevant
+ # methods (e.g. "plan" and "lazy_plan" but not "sparse_plan", etc.).
+ # log is a dictionary containing potential information about the solver
+
+ # Dual potentials --------------------------------------------
+
+ def __repr__(self):
+ s = 'OTResult('
+ if self._value is not None:
+ s += 'value={},'.format(self._value)
+ if self._value_linear is not None:
+ s += 'value_linear={},'.format(self._value_linear)
+ if self._plan is not None:
+ s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)
+
+ if s[-1] != '(':
+ s = s[:-1] + ')'
+ else:
+ s = s + ')'
+ return s
+
+ @property
+ def potentials(self):
+ """Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
+
+ This pair of arrays has the same shape, numerical type
+ and properties as the input weights "a" and "b".
+ """
+ if self._potentials is not None:
+ return self._potentials
+ else:
+ raise NotImplementedError()
+
+ @property
+ def potential_a(self):
+ """First dual potential, associated to the "source" measure "a"."""
+ if self._potentials is not None:
+ return self._potentials[0]
+ else:
+ raise NotImplementedError()
+
+ @property
+ def potential_b(self):
+ """Second dual potential, associated to the "target" measure "b"."""
+ if self._potentials is not None:
+ return self._potentials[1]
+ else:
+ raise NotImplementedError()
+
+ # Transport plan -------------------------------------------
+ @property
+ def plan(self):
+ """Transport plan, encoded as a dense array."""
+ # N.B.: We may catch out-of-memory errors and suggest
+ # the use of lazy_plan or sparse_plan when appropriate.
+
+ if self._plan is not None:
+ return self._plan
+ else:
+ raise NotImplementedError()
+
+ @property
+ def sparse_plan(self):
+ """Transport plan, encoded as a sparse array."""
+ if self._sparse_plan is not None:
+ return self._sparse_plan
+ elif self._plan is not None:
+ return self._backend.tocsr(self._plan)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def lazy_plan(self):
+ """Transport plan, encoded as a symbolic KeOps LazyTensor."""
+ raise NotImplementedError()
+
+ # Loss values --------------------------------
+
+ @property
+ def value(self):
+ """Full transport cost, including possible regularization terms."""
+ if self._value is not None:
+ return self._value
+ else:
+ raise NotImplementedError()
+
+ @property
+ def value_linear(self):
+ """The "minimal" transport cost, i.e. the product between the transport plan and the cost."""
+ if self._value_linear is not None:
+ return self._value_linear
+ else:
+ raise NotImplementedError()
+
+ # Marginal constraints -------------------------
+ @property
+ def marginals(self):
+ """Marginals of the transport plan: should be very close to "a" and "b"
+ for balanced OT."""
+ if self._plan is not None:
+ return self.marginal_a, self.marginal_b
+ else:
+ raise NotImplementedError()
+
+ @property
+ def marginal_a(self):
+ """First marginal of the transport plan, with the same shape as "a"."""
+ if self._plan is not None:
+ return self._backend.sum(self._plan, 1)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def marginal_b(self):
+ """Second marginal of the transport plan, with the same shape as "b"."""
+ if self._plan is not None:
+ return self._backend.sum(self._plan, 0)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def status(self):
+ """Optimization status of the solver."""
+ if self._status is not None:
+ return self._status
+ else:
+ raise NotImplementedError()
+
+ # Barycentric mappings -------------------------
+ # Return the displacement vectors as an array
+ # that has the same shape as "xa"/"xb" (for samples)
+ # or "a"/"b" * D (for images)?
+
+ @property
+ def a_to_b(self):
+ """Displacement vectors from the first to the second measure."""
+ raise NotImplementedError()
+
+ @property
+ def b_to_a(self):
+ """Displacement vectors from the second to the first measure."""
+ raise NotImplementedError()
+
+ # # Wasserstein barycenters ----------------------
+ # @property
+ # def masses(self):
+ # """Masses for the Wasserstein barycenter."""
+ # raise NotImplementedError()
+
+ # @property
+ # def samples(self):
+ # """Sample locations for the Wasserstein barycenter."""
+ # raise NotImplementedError()
+
+ # Miscellaneous --------------------------------
+
+ @property
+ def citation(self):
+ """Appropriate citation(s) for this result, in plain text and BibTex formats."""
+
+ # The string below refers to the POT library:
+ # successor methods may concatenate the relevant references
+ # to the original definitions, solvers and underlying numerical backends.
+ return """POT library:
+
+ POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
+ Website: https://pythonot.github.io/
+ Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;
+
+ @article{flamary2021pot,
+ author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
+ title = {{POT}: {Python} {Optimal} {Transport}},
+ journal = {Journal of Machine Learning Research},
+ year = {2021},
+ volume = {22},
+ number = {78},
+ pages = {1-8},
+ url = {http://jmlr.org/papers/v22/20-451.html}
+ }
+ """
diff --git a/test/test_backend.py b/test/test_backend.py
index 311c075..3628f61 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -275,6 +275,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.sqrtm(M)
with pytest.raises(NotImplementedError):
+ nx.kl_div(M, M)
+ with pytest.raises(NotImplementedError):
nx.isfinite(M)
with pytest.raises(NotImplementedError):
nx.array_equal(M, M)
@@ -592,6 +594,10 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("matrix square root")
+ A = nx.kl_div(nx.abs(Mb), nx.abs(Mb) + 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("Kullback-Leibler divergence")
+
A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0)
A = nx.isfinite(A)
lst_b.append(nx.to_numpy(A))
diff --git a/test/test_partial.py b/test/test_partial.py
index 33fc259..ae4a1ab 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -79,6 +79,8 @@ def test_partial_wasserstein_lagrange():
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+ w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True)
+
def test_partial_wasserstein():
diff --git a/test/test_solvers.py b/test/test_solvers.py
new file mode 100644
index 0000000..b792aca
--- /dev/null
+++ b/test/test_solvers.py
@@ -0,0 +1,133 @@
+"""Tests for ot solvers"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+import itertools
+import numpy as np
+import pytest
+
+import ot
+
+
+lst_reg = [None, 1.0]
+lst_reg_type = ['KL', 'entropy', 'L2']
+lst_unbalanced = [None, 0.9]
+lst_unbalanced_type = ['KL', 'L2', 'TV']
+
+
+def assert_allclose_sol(sol1, sol2):
+
+ lst_attr = ['value', 'value_linear', 'plan',
+ 'potential_a', 'potential_b', 'marginal_a', 'marginal_b']
+
+ nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend()
+ nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend()
+
+ for attr in lst_attr:
+ try:
+ np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr)))
+ except NotImplementedError:
+ pass
+
+
+def test_solve(nx):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ # solve unif weights
+ sol0 = ot.solve(M)
+
+ print(sol0)
+
+ # solve signe weights
+ sol = ot.solve(M, a, b)
+
+ # check some attributes
+ sol.potentials
+ sol.sparse_plan
+ sol.marginals
+ sol.status
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(M, a, b)
+
+ assert_allclose_sol(sol, solb)
+
+ # test not implemented unbalanced and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence')
+
+ # test not implemented reg_type and check raise
+ with pytest.raises(NotImplementedError):
+ sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence')
+
+
+@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
+def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+ a = ot.utils.unif(n_samples_s)
+ b = ot.utils.unif(n_samples_t)
+
+ M = ot.dist(x, y)
+
+ try:
+
+ # solve unif weights
+ sol0 = ot.solve(M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ # solve signe weights
+ sol = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol0, sol)
+
+ # solve in backend
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+ solb = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type)
+
+ assert_allclose_sol(sol, solb)
+ except NotImplementedError:
+ pass
+
+
+def test_solve_not_implemented(nx):
+
+ n_samples_s = 10
+ n_samples_t = 7
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples_s, n_features)
+ y = rng.randn(n_samples_t, n_features)
+
+ M = ot.dist(x, y)
+
+ # test not implemented and check raise
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='cryptic divergence')
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence')
+
+ # pairs of incompatible divergences
+ with pytest.raises(NotImplementedError):
+ ot.solve(M, reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv')
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index fc40df0..b76d738 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -5,6 +5,7 @@
#
# License: MIT License
+import itertools
import numpy as np
import ot
import pytest
@@ -289,6 +290,28 @@ def test_implemented_methods(nx):
method=method)
+@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2']))
+def test_lbfgsb_unbalanced(nx, reg_div, regm_div):
+
+ np.random.seed(42)
+
+ xs = np.random.randn(5, 2)
+ xt = np.random.randn(6, 2)
+
+ M = ot.dist(xs, xt)
+
+ a = ot.unif(5)
+ b = ot.unif(6)
+
+ G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
def test_mm_convergence(nx):
n = 100
rng = np.random.RandomState(42)
diff --git a/test/test_utils.py b/test/test_utils.py
index 19b6365..666c157 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -301,3 +301,32 @@ def test_BaseEstimator():
cl.set_params(bibi=10)
assert cl.first == 'spam again'
+
+
+def test_OTResult():
+
+ res = ot.utils.OTResult()
+
+ # test print
+ print(res)
+
+ # tets get citation
+ print(res.citation)
+
+ lst_attributes = ['a_to_b',
+ 'b_to_a',
+ 'lazy_plan',
+ 'marginal_a',
+ 'marginal_b',
+ 'marginals',
+ 'plan',
+ 'potential_a',
+ 'potential_b',
+ 'potentials',
+ 'sparse_plan',
+ 'status',
+ 'value',
+ 'value_linear']
+ for at in lst_attributes:
+ with pytest.raises(NotImplementedError):
+ getattr(res, at)