author      ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>    2021-10-25 17:35:36 +0200
committer   GitHub <noreply@github.com>    2021-10-25 17:35:36 +0200
commit      76450dddf8dd62b9714b72e99ae075516246d433 (patch)
tree        67de8de1c185cc8e7fc33a1fc0613015824d1fbb
parent      7a65086dd340265d0223eb8ffb5c9a5152a82dff (diff)
[MRG] Backend for optim (#282)
* Backend for optim
* Bug solve
* Doc update
* backend tests now with fixture
* Unused imports removed
* Docs
* Docs
* Docs
* Outer product backend docs
* Prettier docs
* Pep8
* Mistakes corrected

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--    ot/backend.py           118
-rw-r--r--    ot/lp/__init__.py         4
-rw-r--r--    ot/optim.py             155
-rw-r--r--    test/test_backend.py     22
-rw-r--r--    test/test_optim.py       78
-rw-r--r--    test/test_ot.py           6
-rw-r--r--    test/test_utils.py        7
7 files changed, 250 insertions, 140 deletions
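In short, the patch adds outer and reshape to the generic Backend API and rewrites ot.optim (line_search_armijo and the conditional-gradient solvers cg and gcg) to dispatch through get_backend, so these solvers accept NumPy, JAX or Torch inputs directly. A minimal usage sketch of what this enables, with made-up data and the same quadratic regularizer used in test_optim.py (assumes torch is installed; not code from the patch itself):

    # Sketch only: ot.optim.cg called directly on torch tensors after this patch.
    # The regularizer f and gradient df mirror the ones used in test_optim.py.
    import torch
    import ot

    a = torch.ones(5, dtype=torch.float64) / 5      # source weights (sum to 1)
    b = torch.ones(6, dtype=torch.float64) / 6      # target weights (sum to 1)
    M = torch.rand(5, 6, dtype=torch.float64)       # cost matrix

    def f(G):
        return 0.5 * torch.sum(G ** 2)              # quadratic regularization term

    def df(G):
        return G                                    # its gradient

    G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df)  # returns a torch tensor
    print(type(G), float(G.sum()))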
diff --git a/ot/backend.py b/ot/backend.py
index a4a4757..876b96a 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -123,7 +123,7 @@ class Backend():
r"""
Creates a tensor full of zeros.
- This function follow the api from :any:`numpy.zeros`
+ This function follows the api from :any:`numpy.zeros`
See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html
"""
@@ -133,7 +133,7 @@ class Backend():
r"""
Creates a tensor full of ones.
- This function follow the api from :any:`numpy.ones`
+ This function follows the api from :any:`numpy.ones`
See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html
"""
@@ -143,7 +143,7 @@ class Backend():
r"""
Returns evenly spaced values within a given interval.
- This function follow the api from :any:`numpy.arange`
+ This function follows the api from :any:`numpy.arange`
See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
"""
@@ -153,7 +153,7 @@ class Backend():
r"""
Creates a tensor with given shape, filled with given value.
- This function follow the api from :any:`numpy.full`
+ This function follows the api from :any:`numpy.full`
See: https://numpy.org/doc/stable/reference/generated/numpy.full.html
"""
@@ -163,7 +163,7 @@ class Backend():
r"""
Creates the identity matrix of given size.
- This function follow the api from :any:`numpy.eye`
+ This function follows the api from :any:`numpy.eye`
See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html
"""
@@ -173,7 +173,7 @@ class Backend():
r"""
Sums tensor elements over given dimensions.
- This function follow the api from :any:`numpy.sum`
+ This function follows the api from :any:`numpy.sum`
See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html
"""
@@ -183,7 +183,7 @@ class Backend():
r"""
Returns the cumulative sum of tensor elements over given dimensions.
- This function follow the api from :any:`numpy.cumsum`
+ This function follows the api from :any:`numpy.cumsum`
See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
"""
@@ -193,7 +193,7 @@ class Backend():
r"""
Returns the maximum of an array or maximum along given dimensions.
- This function follow the api from :any:`numpy.amax`
+ This function follows the api from :any:`numpy.amax`
See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html
"""
@@ -203,7 +203,7 @@ class Backend():
r"""
Returns the maximum of an array or maximum along given dimensions.
- This function follow the api from :any:`numpy.amin`
+ This function follows the api from :any:`numpy.amin`
See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html
"""
@@ -213,7 +213,7 @@ class Backend():
r"""
Returns element-wise maximum of array elements.
- This function follow the api from :any:`numpy.maximum`
+ This function follows the api from :any:`numpy.maximum`
See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
"""
@@ -223,7 +223,7 @@ class Backend():
r"""
Returns element-wise minimum of array elements.
- This function follow the api from :any:`numpy.minimum`
+ This function follows the api from :any:`numpy.minimum`
See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
"""
@@ -233,7 +233,7 @@ class Backend():
r"""
Returns the dot product of two tensors.
- This function follow the api from :any:`numpy.dot`
+ This function follows the api from :any:`numpy.dot`
See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
"""
@@ -243,7 +243,7 @@ class Backend():
r"""
Computes the absolute value element-wise.
- This function follow the api from :any:`numpy.absolute`
+ This function follows the api from :any:`numpy.absolute`
See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html
"""
@@ -253,7 +253,7 @@ class Backend():
r"""
Computes the exponential value element-wise.
- This function follow the api from :any:`numpy.exp`
+ This function follows the api from :any:`numpy.exp`
See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html
"""
@@ -263,7 +263,7 @@ class Backend():
r"""
Computes the natural logarithm, element-wise.
- This function follow the api from :any:`numpy.log`
+ This function follows the api from :any:`numpy.log`
See: https://numpy.org/doc/stable/reference/generated/numpy.log.html
"""
@@ -273,7 +273,7 @@ class Backend():
r"""
Returns the non-ngeative square root of a tensor, element-wise.
- This function follow the api from :any:`numpy.sqrt`
+ This function follows the api from :any:`numpy.sqrt`
See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html
"""
@@ -283,7 +283,7 @@ class Backend():
r"""
First tensor elements raised to powers from second tensor, element-wise.
- This function follow the api from :any:`numpy.power`
+ This function follows the api from :any:`numpy.power`
See: https://numpy.org/doc/stable/reference/generated/numpy.power.html
"""
@@ -293,7 +293,7 @@ class Backend():
r"""
Computes the matrix frobenius norm.
- This function follow the api from :any:`numpy.linalg.norm`
+ This function follows the api from :any:`numpy.linalg.norm`
See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
"""
@@ -303,7 +303,7 @@ class Backend():
r"""
Tests whether any tensor element along given dimensions evaluates to True.
- This function follow the api from :any:`numpy.any`
+ This function follows the api from :any:`numpy.any`
See: https://numpy.org/doc/stable/reference/generated/numpy.any.html
"""
@@ -313,7 +313,7 @@ class Backend():
r"""
Tests element-wise for NaN and returns result as a boolean tensor.
- This function follow the api from :any:`numpy.isnan`
+ This function follows the api from :any:`numpy.isnan`
See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html
"""
@@ -323,7 +323,7 @@ class Backend():
r"""
Tests element-wise for positive or negative infinity and returns result as a boolean tensor.
- This function follow the api from :any:`numpy.isinf`
+ This function follows the api from :any:`numpy.isinf`
See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html
"""
@@ -333,7 +333,7 @@ class Backend():
r"""
Evaluates the Einstein summation convention on the operands.
- This function follow the api from :any:`numpy.einsum`
+ This function follows the api from :any:`numpy.einsum`
See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
"""
@@ -343,7 +343,7 @@ class Backend():
r"""
Returns a sorted copy of a tensor.
- This function follow the api from :any:`numpy.sort`
+ This function follows the api from :any:`numpy.sort`
See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html
"""
@@ -353,7 +353,7 @@ class Backend():
r"""
Returns the indices that would sort a tensor.
- This function follow the api from :any:`numpy.argsort`
+ This function follows the api from :any:`numpy.argsort`
See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
"""
@@ -363,7 +363,7 @@ class Backend():
r"""
Finds indices where elements should be inserted to maintain order in given tensor.
- This function follow the api from :any:`numpy.searchsorted`
+ This function follows the api from :any:`numpy.searchsorted`
See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
"""
@@ -373,7 +373,7 @@ class Backend():
r"""
Reverses the order of elements in a tensor along given dimensions.
- This function follow the api from :any:`numpy.flip`
+ This function follows the api from :any:`numpy.flip`
See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html
"""
@@ -383,7 +383,7 @@ class Backend():
"""
Limits the values in a tensor.
- This function follow the api from :any:`numpy.clip`
+ This function follows the api from :any:`numpy.clip`
See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html
"""
@@ -393,7 +393,7 @@ class Backend():
r"""
Repeats elements of a tensor.
- This function follow the api from :any:`numpy.repeat`
+ This function follows the api from :any:`numpy.repeat`
See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html
"""
@@ -403,7 +403,7 @@ class Backend():
r"""
Gathers elements of a tensor along given dimensions.
- This function follow the api from :any:`numpy.take_along_axis`
+ This function follows the api from :any:`numpy.take_along_axis`
See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
"""
@@ -413,7 +413,7 @@ class Backend():
r"""
Joins a sequence of tensors along an existing dimension.
- This function follow the api from :any:`numpy.concatenate`
+ This function follows the api from :any:`numpy.concatenate`
See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html
"""
@@ -423,7 +423,7 @@ class Backend():
r"""
Pads a tensor.
- This function follow the api from :any:`numpy.pad`
+ This function follows the api from :any:`numpy.pad`
See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
"""
@@ -433,7 +433,7 @@ class Backend():
r"""
Returns the indices of the maximum values of a tensor along given dimensions.
- This function follow the api from :any:`numpy.argmax`
+ This function follows the api from :any:`numpy.argmax`
See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
"""
@@ -443,7 +443,7 @@ class Backend():
r"""
Computes the arithmetic mean of a tensor along given dimensions.
- This function follow the api from :any:`numpy.mean`
+ This function follows the api from :any:`numpy.mean`
See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html
"""
@@ -453,7 +453,7 @@ class Backend():
r"""
Computes the standard deviation of a tensor along given dimensions.
- This function follow the api from :any:`numpy.std`
+ This function follows the api from :any:`numpy.std`
See: https://numpy.org/doc/stable/reference/generated/numpy.std.html
"""
@@ -463,7 +463,7 @@ class Backend():
r"""
Returns a specified number of evenly spaced values over a given interval.
- This function follow the api from :any:`numpy.linspace`
+ This function follows the api from :any:`numpy.linspace`
See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html
"""
@@ -473,7 +473,7 @@ class Backend():
r"""
Returns coordinate matrices from coordinate vectors (Numpy convention).
- This function follow the api from :any:`numpy.meshgrid`
+ This function follows the api from :any:`numpy.meshgrid`
See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
"""
@@ -483,7 +483,7 @@ class Backend():
r"""
Extracts or constructs a diagonal tensor.
- This function follow the api from :any:`numpy.diag`
+ This function follows the api from :any:`numpy.diag`
See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html
"""
@@ -493,7 +493,7 @@ class Backend():
r"""
Finds unique elements of given tensor.
- This function follow the api from :any:`numpy.unique`
+ This function follows the api from :any:`numpy.unique`
See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
"""
@@ -503,7 +503,7 @@ class Backend():
r"""
Computes the log of the sum of exponentials of input elements.
- This function follow the api from :any:`scipy.special.logsumexp`
+ This function follows the api from :any:`scipy.special.logsumexp`
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
"""
@@ -513,12 +513,32 @@ class Backend():
r"""
Joins a sequence of tensors along a new dimension.
- This function follow the api from :any:`numpy.stack`
+ This function follows the api from :any:`numpy.stack`
See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html
"""
raise NotImplementedError()
+ def outer(self, a, b):
+ r"""
+ Computes the outer product between two vectors.
+
+ This function follows the api from :any:`numpy.outer`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html
+ """
+ raise NotImplementedError()
+
+ def reshape(self, a, shape):
+ r"""
+ Gives a new shape to a tensor without changing its data.
+
+ This function follows the api from :any:`numpy.reshape`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -644,6 +664,9 @@ class NumpyBackend(Backend):
def flip(self, a, axis=None):
return np.flip(a, axis)
+ def outer(self, a, b):
+ return np.outer(a, b)
+
def clip(self, a, a_min, a_max):
return np.clip(a, a_min, a_max)
@@ -686,6 +709,9 @@ class NumpyBackend(Backend):
def stack(self, arrays, axis=0):
return np.stack(arrays, axis)
+ def reshape(self, a, shape):
+ return np.reshape(a, shape)
+
class JaxBackend(Backend):
"""
@@ -815,6 +841,9 @@ class JaxBackend(Backend):
def flip(self, a, axis=None):
return jnp.flip(a, axis)
+ def outer(self, a, b):
+ return jnp.outer(a, b)
+
def clip(self, a, a_min, a_max):
return jnp.clip(a, a_min, a_max)
@@ -857,6 +886,9 @@ class JaxBackend(Backend):
def stack(self, arrays, axis=0):
return jnp.stack(arrays, axis)
+ def reshape(self, a, shape):
+ return jnp.reshape(a, shape)
+
class TorchBackend(Backend):
"""
@@ -1035,6 +1067,9 @@ class TorchBackend(Backend):
else:
return torch.flip(a, dims=axis)
+ def outer(self, a, b):
+ return torch.outer(a, b)
+
def clip(self, a, a_min, a_max):
return torch.clamp(a, a_min, a_max)
@@ -1091,3 +1126,6 @@ class TorchBackend(Backend):
def stack(self, arrays, axis=0):
return torch.stack(arrays, dim=axis)
+
+ def reshape(self, a, shape):
+ return torch.reshape(a, shape)
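The additions above keep the backend classes as thin wrappers over numpy, jax.numpy and torch. A small illustrative check of the two new primitives against NumPy (arrays invented for the example, not taken from the patch):

    # Illustrative check of the new outer / reshape backend primitives.
    import numpy as np
    from ot.backend import get_backend, NumpyBackend

    u = np.array([1.0, 2.0, 3.0])
    v = np.array([10.0, 20.0])

    nx = get_backend(u, v)             # picks the backend matching the inputs
    K = nx.outer(u, v)                 # outer product, shape (3, 2)
    K_flat = nx.reshape(K, (6,))       # flatten, same result as np.reshape(K, (6,))

    assert isinstance(nx, NumpyBackend)
    np.testing.assert_allclose(nx.to_numpy(K), np.outer(u, v))
    np.testing.assert_allclose(nx.to_numpy(K_flat), np.outer(u, v).reshape(6))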
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index b907b10..c6757d1 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -281,12 +281,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
a0, b0, M0 = a, b, M
nx = get_backend(M0, a0, b0)
-
+
# convert to numpy
M = nx.to_numpy(M)
a = nx.to_numpy(a)
b = nx.to_numpy(b)
-
+
# ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
diff --git a/ot/optim.py b/ot/optim.py
index 0359343..6822e4e 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -12,6 +12,8 @@ import numpy as np
from scipy.optimize.linesearch import scalar_search_armijo
from .lp import emd
from .bregman import sinkhorn
+from ot.utils import list_to_array
+from .backend import get_backend
# The corresponding scipy function does not work for matrices
@@ -21,25 +23,25 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
"""
Armijo linesearch function that works with matrices
- find an approximate minimum of f(xk+alpha*pk) that satifies the
+ Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the
armijo conditions.
Parameters
----------
f : callable
loss function
- xk : ndarray
+ xk : array-like
initial position
- pk : ndarray
+ pk : array-like
descent direction
- gfk : ndarray
- gradient of f at xk
+ gfk : array-like
+ gradient of `f` at :math:`x_k`
old_fval : float
- loss value at xk
+ loss value at :math:`x_k`
args : tuple, optional
- arguments given to f
+ arguments given to `f`
c1 : float, optional
- c1 const in armijo rule (>0)
+ :math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
@@ -53,7 +55,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
loss value at step alpha
"""
- xk = np.atleast_1d(xk)
+
+ xk, pk, gfk = list_to_array(xk, pk, gfk)
+ nx = get_backend(xk, pk)
+
+ if len(xk.shape) == 0:
+ xk = nx.reshape(xk, (-1,))
+
fc = [0]
def phi(alpha1):
@@ -65,7 +73,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
else:
phi0 = old_fval
- derphi0 = np.sum(pk * gfk) # Quickfix for matrices
+ derphi0 = nx.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
@@ -79,55 +87,64 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
+
Parameters
----------
cost : method
Cost in the FW for the linesearch
- G : ndarray, shape(ns,nt)
+ G : array-like, shape(ns,nt)
The transport map at a given iteration of the FW
- deltaG : ndarray (ns,nt)
+ deltaG : array-like (ns,nt)
Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
- Mi : ndarray (ns,nt)
+ Mi : array-like (ns,nt)
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
- f_val : float
- Value of the cost at G
+ f_val : float
+ Value of the cost at `G`
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
- C1 : ndarray (ns,ns), optional
+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
+ If there is convergence issues use False.
+ C1 : array-like (ns,ns), optional
Structure matrix in the source domain. Only used and necessary when armijo=False
- C2 : ndarray (nt,nt), optional
+ C2 : array-like (nt,nt), optional
Structure matrix in the target domain. Only used and necessary when armijo=False
reg : float, optional
- Regularization parameter. Only used and necessary when armijo=False
- Gc : ndarray (ns,nt)
+ Regularization parameter. Only used and necessary when armijo=False
+ Gc : array-like (ns,nt)
Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
- constC : ndarray (ns,nt)
- Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
- M : ndarray (ns,nt), optional
+ constC : array-like (ns,nt)
+ Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
+ M : array-like (ns,nt), optional
Cost matrix between the features. Only used and necessary when armijo=False
+
Returns
-------
alpha : float
- The optimal step size of the FW
+ The optimal step size of the FW
fc : int
- nb of function call. Useless here
- f_val : float
- The value of the cost for the next iteration
+ nb of function call. Useless here
+ f_val : float
+ The value of the cost for the next iteration
+
+
+ .. _references-solve-linesearch:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
- and Courty Nicolas
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
if armijo:
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
else: # requires symetric matrices
- dot1 = np.dot(C1, deltaG)
- dot12 = dot1.dot(C2)
- a = -2 * reg * np.sum(dot12 * deltaG)
- b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+ G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(G, deltaG, C1, C2, constC)
+ else:
+ nx = get_backend(G, deltaG, C1, C2, constC, M)
+
+ dot = nx.dot(nx.dot(C1, deltaG), C2)
+ a = -2 * reg * nx.sum(dot * deltaG)
+ b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG))
c = cost(G)
alpha = solve_1d_linesearch_quad(a, b, c)
@@ -145,33 +162,33 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma)
+ \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg} \cdot f(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
Parameters
----------
- a : ndarray, shape (ns,)
+ a : array-like, shape (ns,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : array-like, shape (nt,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
- G0 : ndarray, shape (ns,nt), optional
+ G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
@@ -196,6 +213,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
log dictionary return only if log==True in parameters
+ .. _references-cg:
References
----------
@@ -207,6 +225,11 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
ot.bregman.sinkhorn : Entropic regularized optimal transport
"""
+ a, b, M, G0 = list_to_array(a, b, M, G0)
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(a, b)
+ else:
+ nx = get_backend(a, b, M)
loop = 1
@@ -214,12 +237,12 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
log = {'loss': []}
if G0 is None:
- G = np.outer(a, b)
+ G = nx.outer(a, b)
else:
G = G0
def cost(G):
- return np.sum(M * G) + reg * f(G)
+ return nx.sum(M * G) + reg * f(G)
f_val = cost(G)
if log:
@@ -240,7 +263,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
# problem linearization
Mi = M + reg * df(G)
# set M positive
- Mi += Mi.min()
+ Mi += nx.min(Mi)
# solve linear program
Gc = emd(a, b, Mi, numItermax=numItermaxEmd)
@@ -286,36 +309,36 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma)
+ \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_
+ The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
- a : ndarray, shape (ns,)
+ a : array-like, shape (ns,)
samples weights in the source domain
- b : ndarrayv (nt,)
+ b : array-like, (nt,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
loss matrix
reg1 : float
Entropic Regularization term >0
reg2 : float
Second Regularization term >0
- G0 : ndarray, shape (ns, nt), optional
+ G0 : array-like, shape (ns, nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
@@ -337,9 +360,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
log : dict
log dictionary return only if log==True in parameters
+
+ .. _references-gcg:
References
----------
+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
See Also
@@ -347,6 +374,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
ot.optim.cg : conditional gradient
"""
+ a, b, M, G0 = list_to_array(a, b, M, G0)
+ nx = get_backend(a, b, M)
loop = 1
@@ -354,12 +383,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
log = {'loss': []}
if G0 is None:
- G = np.outer(a, b)
+ G = nx.outer(a, b)
else:
G = G0
def cost(G):
- return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
+ return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G)
f_val = cost(G)
if log:
@@ -387,7 +416,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
deltaG = Gc - G
# line search
- dcost = Mi + reg1 * (1 + np.log(G)) # ??
+ dcost = Mi + reg1 * (1 + nx.log(G)) # ??
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
G = G + alpha * deltaG
@@ -419,9 +448,11 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
def solve_1d_linesearch_quad(a, b, c):
"""
- For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem:
+ For any convex or non-convex 1d quadratic function `f`, solve the following problem:
+
.. math::
- \argmin f(x)=a*x^{2}+b*x+c
+
+ arg\min_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c
Parameters
----------
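For reference, the quantity documented in the last hunk, solve_1d_linesearch_quad(a, b, c), is the minimizer of a 1d quadratic restricted to [0, 1]. The following is an independent sketch of that closed form, not code taken from the patch, and the helper name is made up:

    # Hypothetical re-derivation of argmin_{0 <= x <= 1} a*x**2 + b*x + c.
    import numpy as np

    def quad_argmin_unit_interval(a, b, c=0.0):
        if a > 0:                                  # convex: clip the vertex to [0, 1]
            return float(np.clip(-b / (2.0 * a), 0.0, 1.0))
        # linear or concave: the minimum over [0, 1] sits at an endpoint
        return 0.0 if c <= a + b + c else 1.0      # compare f(0) with f(1)

    # Consistent with the value checked in test_optim.py:
    # ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0) == 1
    assert quad_argmin_unit_interval(-1, 0.5, 0) == 1.0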
diff --git a/test/test_backend.py b/test/test_backend.py
index 859da5a..5853282 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -17,9 +17,6 @@ from numpy.testing import assert_array_almost_equal_nulp
from ot.backend import get_backend, get_backend_list, to_numpy
-backend_list = get_backend_list()
-
-
def test_get_backend_list():
lst = get_backend_list()
@@ -28,7 +25,6 @@ def test_get_backend_list():
assert isinstance(lst[0], ot.backend.NumpyBackend)
-@pytest.mark.parametrize('nx', backend_list)
def test_to_numpy(nx):
v = nx.zeros(10)
@@ -92,7 +88,6 @@ def test_get_backend():
get_backend(A, B2)
-@pytest.mark.parametrize('nx', backend_list)
def test_convert_between_backends(nx):
A = np.zeros((3, 2))
@@ -181,6 +176,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.flip(M)
with pytest.raises(NotImplementedError):
+ nx.outer(v, v)
+ with pytest.raises(NotImplementedError):
nx.clip(M, -1, 1)
with pytest.raises(NotImplementedError):
nx.repeat(M, 0, 1)
@@ -208,10 +205,11 @@ def test_empty_backend():
nx.logsumexp(M)
with pytest.raises(NotImplementedError):
nx.stack([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.reshape(M, (5, 3, 2))
-@pytest.mark.parametrize('backend', backend_list)
-def test_func_backends(backend):
+def test_func_backends(nx):
rnd = np.random.RandomState(0)
M = rnd.randn(10, 3)
@@ -220,7 +218,7 @@ def test_func_backends(backend):
lst_tot = []
- for nx in [ot.backend.NumpyBackend(), backend]:
+ for nx in [ot.backend.NumpyBackend(), nx]:
print('Backend: ', nx.__name__)
@@ -371,6 +369,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('flip')
+ A = nx.outer(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('outer')
+
A = nx.clip(vb, 0, 1)
lst_b.append(nx.to_numpy(A))
lst_name.append('clip')
@@ -432,6 +434,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('stack')
+ A = nx.reshape(Mb, (5, 3, 2))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('reshape')
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
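The @pytest.mark.parametrize('nx', backend_list) decorators removed throughout the test files imply that a shared pytest fixture named nx now supplies each backend (the "backend tests now with fixture" item in the commit message). That fixture is not part of this diff; a plausible conftest.py sketch would be:

    # Hypothetical conftest.py fixture, not included in this diff.
    import pytest
    from ot.backend import get_backend_list

    @pytest.fixture(params=get_backend_list())
    def nx(request):
        return request.param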
diff --git a/test/test_optim.py b/test/test_optim.py
index 94995d5..4efd9b1 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -8,7 +8,7 @@ import numpy as np
import ot
-def test_conditional_gradient():
+def test_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -29,15 +29,25 @@ def test_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_conditional_gradient_itermax():
+def test_conditional_gradient_itermax(nx):
n = 100 # nb samples
mu_s = np.array([0, 0])
@@ -61,16 +71,27 @@ def test_conditional_gradient_itermax():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
reg = 1e-1
G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000,
verbose=True, log=True)
+ Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000,
+ verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1))
- np.testing.assert_allclose(b, G.sum(0))
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1))
+ np.testing.assert_allclose(b, Gb.sum(0))
-def test_generalized_conditional_gradient():
+def test_generalized_conditional_gradient(nx):
n_bins = 100 # nb bins
np.random.seed(0)
@@ -91,13 +112,23 @@ def test_generalized_conditional_gradient():
def df(G):
return G
+ def fb(G):
+ return 0.5 * nx.sum(G ** 2)
+
reg1 = 1e-3
reg2 = 1e-1
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
+ Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
+ Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
- np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+ np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)
def test_solve_1d_linesearch_quad_funct():
@@ -106,24 +137,31 @@ def test_solve_1d_linesearch_quad_funct():
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
-def test_line_search_armijo():
+def test_line_search_armijo(nx):
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
old_fval = -123
# Should not throw an exception and return None for alpha
- alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
+ alpha, a, b = ot.optim.line_search_armijo(
+ lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+ )
+ alpha_np, anp, bnp = ot.optim.line_search_armijo(
+ lambda x: 1, xk, pk, gfk, old_fval
+ )
+ assert a == anp
+ assert b == bnp
assert alpha is None
# check line search armijo
def f(x):
- return np.sum((x - 5.0) ** 2)
+ return nx.sum((x - 5.0) ** 2)
def grad(x):
return 2 * (x - 5.0)
- xk = np.array([[[-5.0, -5.0]]])
- pk = np.array([[[100.0, 100.0]]])
+ xk = nx.from_numpy(np.array([[[-5.0, -5.0]]]))
+ pk = nx.from_numpy(np.array([[[100.0, 100.0]]]))
gfk = grad(xk)
old_fval = f(xk)
@@ -132,10 +170,18 @@ def test_line_search_armijo():
np.testing.assert_allclose(alpha, 0.1)
# check the case where the direction is not far enough
- pk = np.array([[[3.0, 3.0]]])
+ pk = nx.from_numpy(np.array([[[3.0, 3.0]]]))
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
np.testing.assert_allclose(alpha, 1.0)
- # check the case where the checking the wrong direction
+ # check the case where checking the wrong direction
alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
assert alpha <= 0
+
+ # check the case where the point is not a vector
+ xk = nx.from_numpy(np.array(-5.0))
+ pk = nx.from_numpy(np.array(100.0))
+ gfk = grad(xk)
+ old_fval = f(xk)
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
diff --git a/test/test_ot.py b/test/test_ot.py
index 3e953dc..4dfc510 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,9 +12,7 @@ from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
-from ot.backend import get_backend_list, torch
-
-backend_list = get_backend_list()
+from ot.backend import torch
def test_emd_dimension_and_mass_mismatch():
@@ -37,7 +35,6 @@ def test_emd_dimension_and_mass_mismatch():
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
-@pytest.mark.parametrize('nx', backend_list)
def test_emd_backends(nx):
n_samples = 100
n_features = 2
@@ -59,7 +56,6 @@ def test_emd_backends(nx):
np.allclose(G, nx.to_numpy(Gb))
-@pytest.mark.parametrize('nx', backend_list)
def test_emd2_backends(nx):
n_samples = 100
n_features = 2
diff --git a/test/test_utils.py b/test/test_utils.py
index 76b1faa..60ad5d3 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,17 +4,11 @@
#
# License: MIT License
-import pytest
import ot
import numpy as np
import sys
-from ot.backend import get_backend_list
-backend_list = get_backend_list()
-
-
-@pytest.mark.parametrize('nx', backend_list)
def test_proj_simplex(nx):
n = 10
rng = np.random.RandomState(0)
@@ -119,7 +113,6 @@ def test_dist():
np.testing.assert_allclose(D, D3, atol=1e-14)
-@ pytest.mark.parametrize('nx', backend_list)
def test_dist_backends(nx):
n = 100