Diffstat (limited to 'ot/optim.py')
-rw-r--r--  ot/optim.py  189
1 file changed, 114 insertions, 75 deletions
diff --git a/ot/optim.py b/ot/optim.py
index b9ca891..bd8ca26 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -12,34 +12,36 @@ import numpy as np
from scipy.optimize.linesearch import scalar_search_armijo
from .lp import emd
from .bregman import sinkhorn
+from ot.utils import list_to_array
+from .backend import get_backend
# The corresponding scipy function does not work for matrices
def line_search_armijo(f, xk, pk, gfk, old_fval,
args=(), c1=1e-4, alpha0=0.99):
- """
+ r"""
Armijo linesearch function that works with matrices
- find an approximate minimum of f(xk+alpha*pk) that satifies the
+ Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
armijo conditions.
Parameters
----------
f : callable
loss function
- xk : ndarray
+ xk : array-like
initial position
- pk : ndarray
+ pk : array-like
descent direction
- gfk : ndarray
- gradient of f at xk
+ gfk : array-like
+ gradient of `f` at :math:`x_k`
old_fval : float
- loss value at xk
+ loss value at :math:`x_k`
args : tuple, optional
- arguments given to f
+ arguments given to `f`
c1 : float, optional
- c1 const in armijo rule (>0)
+ :math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
@@ -53,7 +55,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
loss value at step alpha
"""
- xk = np.atleast_1d(xk)
+
+ xk, pk, gfk = list_to_array(xk, pk, gfk)
+ nx = get_backend(xk, pk)
+
+ if len(xk.shape) == 0:
+ xk = nx.reshape(xk, (-1,))
+
fc = [0]
def phi(alpha1):
@@ -65,10 +73,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
else:
phi0 = old_fval
- derphi0 = np.sum(pk * gfk) # Quickfix for matrices
+ derphi0 = nx.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
+ # scalar_search_armijo can return alpha > 1
+ if alpha is not None:
+ alpha = min(1, alpha)
return alpha, fc[0], phi1
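A quick illustration of the signature documented above, as a minimal sketch with made-up toy arrays (only NumPy and this module are assumed)::

    import numpy as np
    from ot.optim import line_search_armijo

    # Toy quadratic cost on a matrix: f(G) = 0.5 * ||G||_F^2, so its gradient at G is G
    def f(G):
        return 0.5 * np.sum(G ** 2)

    G = np.ones((3, 4))        # current point x_k
    grad = G                   # gradient of f at x_k
    pk = -grad                 # a descent direction
    alpha, n_calls, f_alpha = line_search_armijo(f, G, pk, grad, f(G))
    # alpha is the accepted step (capped at 1 by the change above), or None if the search fails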
@@ -76,55 +87,64 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
+
Parameters
----------
cost : method
Cost in the FW for the linesearch
- G : ndarray, shape(ns,nt)
+ G : array-like, shape(ns,nt)
The transport map at a given iteration of the FW
- deltaG : ndarray (ns,nt)
+ deltaG : array-like (ns,nt)
Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
- Mi : ndarray (ns,nt)
+ Mi : array-like (ns,nt)
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
- f_val : float
- Value of the cost at G
+ f_val : float
+ Value of the cost at `G`
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
- C1 : ndarray (ns,ns), optional
+        If True the step of the line-search is found via an Armijo line search. Otherwise a closed form is used.
+        If there are convergence issues use False.
+ C1 : array-like (ns,ns), optional
Structure matrix in the source domain. Only used and necessary when armijo=False
- C2 : ndarray (nt,nt), optional
+ C2 : array-like (nt,nt), optional
Structure matrix in the target domain. Only used and necessary when armijo=False
reg : float, optional
- Regularization parameter. Only used and necessary when armijo=False
- Gc : ndarray (ns,nt)
+ Regularization parameter. Only used and necessary when armijo=False
+ Gc : array-like (ns,nt)
Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
- constC : ndarray (ns,nt)
- Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
- M : ndarray (ns,nt), optional
+ constC : array-like (ns,nt)
+ Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
+ M : array-like (ns,nt), optional
Cost matrix between the features. Only used and necessary when armijo=False
+
Returns
-------
alpha : float
- The optimal step size of the FW
+ The optimal step size of the FW
fc : int
- nb of function call. Useless here
- f_val : float
- The value of the cost for the next iteration
+        number of function calls (not used here)
+ f_val : float
+ The value of the cost for the next iteration
+
+
+ .. _references-solve-linesearch:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
- and Courty Nicolas
+    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
if armijo:
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
else:  # requires symmetric matrices
- dot1 = np.dot(C1, deltaG)
- dot12 = dot1.dot(C2)
- a = -2 * reg * np.sum(dot12 * deltaG)
- b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+ G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
+    if isinstance(M, (int, float)):
+ nx = get_backend(G, deltaG, C1, C2, constC)
+ else:
+ nx = get_backend(G, deltaG, C1, C2, constC, M)
+
+ dot = nx.dot(nx.dot(C1, deltaG), C2)
+ a = -2 * reg * nx.sum(dot * deltaG)
+ b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG))
c = cost(G)
alpha = solve_1d_linesearch_quad(a, b, c)
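For context on this closed form: the fused Gromov-Wasserstein objective of [24] is quadratic in :math:`G`, so its restriction to the segment :math:`G + \alpha \Delta G` is a one-dimensional quadratic whose coefficients are exactly the `a`, `b`, `c` computed above,

.. math:: \mathrm{cost}(G + \alpha \Delta G) = a\alpha^2 + b\alpha + c, \qquad
    \alpha^\star = \mathop{\arg \min}_{0 \leq \alpha \leq 1} \ a\alpha^2 + b\alpha + c,

which is why the coefficients are handed to `solve_1d_linesearch_quad`.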
@@ -136,48 +156,49 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
- """
+ r"""
Solve the general regularized OT problem with conditional gradient
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
Parameters
----------
- a : ndarray, shape (ns,)
+ a : array-like, shape (ns,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : array-like, shape (nt,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
- G0 : ndarray, shape (ns,nt), optional
+ G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numItermaxEmd : int, optional
Max number of iterations for emd
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -193,6 +214,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
log dictionary return only if log==True in parameters
+ .. _references-cg:
References
----------
@@ -204,6 +226,11 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
ot.bregman.sinkhorn : Entropic regularized optimal transport
"""
+ a, b, M, G0 = list_to_array(a, b, M, G0)
+    if isinstance(M, (int, float)):
+ nx = get_backend(a, b)
+ else:
+ nx = get_backend(a, b, M)
loop = 1
@@ -211,12 +238,12 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
log = {'loss': []}
if G0 is None:
- G = np.outer(a, b)
+ G = nx.outer(a, b)
else:
G = G0
def cost(G):
- return np.sum(M * G) + reg * f(G)
+ return nx.sum(M * G) + reg * f(G)
f_val = cost(G)
if log:
@@ -237,15 +264,17 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
# problem linearization
Mi = M + reg * df(G)
# set M positive
- Mi += Mi.min()
+ Mi += nx.min(Mi)
# solve linear program
- Gc = emd(a, b, Mi, numItermax=numItermaxEmd)
+ Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
deltaG = Gc - G
# line search
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+ if alpha is None:
+ alpha = 0.0
G = G + alpha * deltaG
@@ -268,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
+ log.update(logemd)
return G, log
else:
return G
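For reference, a minimal usage sketch of `cg` in the spirit of the POT regularized-OT examples; the quadratic regularizer `f`, the toy points `x`, and the calls to `ot.unif` / `ot.dist` are illustrative choices, not part of this patch::

    import numpy as np
    import ot

    # Squared-Frobenius regularization: f(G) = 0.5 * ||G||_F^2, with gradient df(G) = G
    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    n = 20
    a = ot.unif(n)                                  # uniform source weights
    b = ot.unif(n)                                  # uniform target weights
    x = np.arange(n, dtype=np.float64).reshape(-1, 1)
    M = ot.dist(x, x)                               # squared Euclidean cost matrix
    M /= M.max()

    G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df)  # regularized transport plan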
@@ -275,51 +305,52 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
- """
+ r"""
Solve the general regularized OT problem with the generalized conditional gradient
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_
+ The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
- a : ndarray, shape (ns,)
+ a : array-like, shape (ns,)
samples weights in the source domain
- b : ndarrayv (nt,)
+    b : array-like, shape (nt,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
loss matrix
reg1 : float
Entropic Regularization term >0
reg2 : float
Second Regularization term >0
- G0 : ndarray, shape (ns, nt), optional
+ G0 : array-like, shape (ns, nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
Max number of iterations of Sinkhorn
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -332,9 +363,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
log : dict
log dictionary return only if log==True in parameters
+
+ .. _references-gcg:
References
----------
+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
See Also
@@ -342,6 +377,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
ot.optim.cg : conditional gradient
"""
+ a, b, M, G0 = list_to_array(a, b, M, G0)
+ nx = get_backend(a, b, M)
loop = 1
@@ -349,12 +386,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
log = {'loss': []}
if G0 is None:
- G = np.outer(a, b)
+ G = nx.outer(a, b)
else:
G = G0
def cost(G):
- return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
+ return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G)
f_val = cost(G)
if log:
@@ -382,7 +419,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
deltaG = Gc - G
# line search
- dcost = Mi + reg1 * (1 + np.log(G)) # ??
+        dcost = Mi + reg1 * (1 + nx.log(G))  # gradient of the full cost: the entropic term contributes reg1 * (1 + log(G))
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
G = G + alpha * deltaG
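`gcg` can be exercised the same way as the `cg` sketch above: `reg1` weights the entropic term and `reg2` plays the role of `reg` for the extra regularizer. A minimal call reusing the same illustrative `a`, `b`, `M`, `f` and `df`::

    # entropic weight reg1, weight reg2 on the extra regularizer f
    G_gcg = ot.optim.gcg(a, b, M, reg1=1e-1, reg2=1e-1, f=f, df=df)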
@@ -413,10 +450,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
def solve_1d_linesearch_quad(a, b, c):
- """
- For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem:
+ r"""
+ For any convex or non-convex 1d quadratic function `f`, solve the following problem:
+
.. math::
- \argmin f(x)=a*x^{2}+b*x+c
+
+ \mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c
Parameters
----------
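The closed form behind `solve_1d_linesearch_quad` is elementary; here is a standalone sketch of that formula (a hypothetical helper for illustration, not necessarily the library's exact implementation)::

    def argmin_quad_on_unit_interval(a, b, c):
        """Minimize f(x) = a * x**2 + b * x + c over [0, 1]."""
        if a > 0:  # convex: clamp the unconstrained minimizer -b / (2a) to [0, 1]
            return min(1.0, max(0.0, -b / (2.0 * a)))
        # concave or linear: the minimum is at an endpoint, compare f(0) = c and f(1) = a + b + c
        return 1.0 if a + b + c < c else 0.0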