Diffstat (limited to 'ot/optim.py')
-rw-r--r--  ot/optim.py  460
1 file changed, 260 insertions(+), 200 deletions(-)
diff --git a/ot/optim.py b/ot/optim.py
index 5a1d605..201f898 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""
-Generic solvers for regularized OT
+Generic solvers for regularized OT or its semi-relaxed version.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-#
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
# License: MIT License
import numpy as np
@@ -27,7 +27,7 @@ with warnings.catch_warnings():
def line_search_armijo(
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
- alpha0=0.99, alpha_min=None, alpha_max=None
+ alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs
):
r"""
Armijo linesearch function that works with matrices
@@ -57,7 +57,8 @@ def line_search_armijo(
minimum value for alpha
alpha_max : float, optional
maximum value for alpha
-
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the input arrays.
Returns
-------
alpha : float
@@ -68,9 +69,9 @@ def line_search_armijo(
loss value at step alpha
"""
-
- xk, pk, gfk = list_to_array(xk, pk, gfk)
- nx = get_backend(xk, pk)
+ if nx is None:
+ xk, pk, gfk = list_to_array(xk, pk, gfk)
+ nx = get_backend(xk, pk)
if len(xk.shape) == 0:
xk = nx.reshape(xk, (-1,))
@@ -98,97 +99,38 @@ def line_search_armijo(
return float(alpha), fc[0], phi1
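The new optional `nx` argument lets a caller pass an already-detected backend and skip the `get_backend` call inside the line search. A minimal sketch of that usage, outside this diff, with a hypothetical quadratic objective:

    import numpy as np
    from ot.backend import get_backend
    from ot.optim import line_search_armijo

    xk = np.full((4, 4), 1. / 16)      # current point
    pk = -xk                           # descent direction (toward zero)

    def f(G):                          # hypothetical smooth objective
        return np.sum(G ** 2)

    gfk = 2 * xk                       # gradient of f at xk
    nx = get_backend(xk, pk)           # backend detected once by the caller
    alpha, fc, fval = line_search_armijo(
        f, xk, pk, gfk, f(xk), nx=nx, alpha_min=0., alpha_max=1.)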
-def solve_linesearch(
- cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None,
- reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None
-):
- """
- Solve the linesearch in the FW iterations
-
- Parameters
- ----------
- cost : method
- Cost in the FW for the linesearch
- G : array-like, shape(ns,nt)
- The transport map at a given iteration of the FW
- deltaG : array-like (ns,nt)
- Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
- Mi : array-like (ns,nt)
- Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
- f_val : float
- Value of the cost at `G`
- armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
- C1 : array-like (ns,ns), optional
- Structure matrix in the source domain. Only used and necessary when armijo=False
- C2 : array-like (nt,nt), optional
- Structure matrix in the target domain. Only used and necessary when armijo=False
- reg : float, optional
- Regularization parameter. Only used and necessary when armijo=False
- Gc : array-like (ns,nt)
- Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
- constC : array-like (ns,nt)
- Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
- M : array-like (ns,nt), optional
- Cost matrix between the features. Only used and necessary when armijo=False
- alpha_min : float, optional
- Minimum value for alpha
- alpha_max : float, optional
- Maximum value for alpha
-
- Returns
- -------
- alpha : float
- The optimal step size of the FW
- fc : int
- nb of function call. Useless here
- f_val : float
- The value of the cost for the next iteration
+def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None,
+ numItermax=200, stopThr=1e-9,
+ stopThr2=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized OT problem or its semi-relaxed version with
+ conditional gradient or generalized conditional gradient depending on the
+ provided linear program solver.
+ When used as a conditional gradient solver, the function solves the following optimization problem:
- .. _references-solve-linesearch:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- if armijo:
- alpha, fc, f_val = line_search_armijo(
- cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
- )
- else: # requires symetric matrices
- G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
- if isinstance(M, int) or isinstance(M, float):
- nx = get_backend(G, deltaG, C1, C2, constC)
- else:
- nx = get_backend(G, deltaG, C1, C2, constC, M)
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_1} \cdot f(\gamma)
- dot = nx.dot(nx.dot(C1, deltaG), C2)
- a = -2 * reg * nx.sum(dot * deltaG)
- b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG))
- c = cost(G)
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- alpha = solve_1d_linesearch_quad(a, b, c)
- if alpha_min is not None or alpha_max is not None:
- alpha = np.clip(alpha, alpha_min, alpha_max)
- fc = None
- f_val = cost(G + alpha * deltaG)
+ \gamma^T \mathbf{1} &= \mathbf{b} \quad \text{(optional constraint)}
- return alpha, fc, f_val
+ \gamma &\geq 0
+ where :
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
-def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
- stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
- r"""
- Solve the general regularized OT problem with conditional gradient
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
- The function solves the following optimization problem:
+ When used as a generalized conditional gradient solver, it solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
- \mathrm{reg} \cdot f(\gamma)
+ \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
@@ -197,29 +139,39 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
\gamma &\geq 0
where :
- - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- - :math:`f` is the regularization term (and `df` is its gradient)
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
-
- The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
a : array-like, shape (ns,)
samples weights in the source domain
b : array-like, shape (nt,)
- samples in the target domain
+ samples weights in the target domain
M : array-like, shape (ns, nt)
loss matrix
- reg : float
+ f : function
+ Regularization function taking a transportation matrix as argument
+ df: function
+ Gradient of the regularization function taking a transportation matrix as argument
+ reg1 : float
Regularization term >0
+ reg2 : float,
+ Entropic regularization term >0. Ignored if set to None.
+ lp_solver: function,
+ Linear program solver for direction finding of the (generalized) conditional gradient.
+ If set to emd, it will solve the general regularized OT problem using cg.
+ If set to lp_semi_relaxed_OT, it will solve the general regularized semi-relaxed OT problem using cg.
+ If set to sinkhorn, it will solve the general regularized OT problem using generalized cg.
+ line_search: function,
+ Function to find the optimal step. Currently used instances are:
+ line_search_armijo (generic solver), solve_gromov_linesearch for the (F)GW problem,
+ solve_semirelaxed_gromov_linesearch for the sr(F)GW problem, and gcg_linesearch for the generalized cg.
G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
- numItermaxEmd : int, optional
- Max number of iterations for emd
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
@@ -240,16 +192,20 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
.. _references-cg:
+ .. _references-gcg:
References
----------
.. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
See Also
--------
ot.lp.emd : Unregularized optimal transport
ot.bregman.sinkhorn : Entropic regularized optimal transport
-
"""
a, b, M, G0 = list_to_array(a, b, M, G0)
if isinstance(M, int) or isinstance(M, float):
@@ -265,42 +221,45 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
if G0 is None:
G = nx.outer(a, b)
else:
- G = G0
-
- def cost(G):
- return nx.sum(M * G) + reg * f(G)
+ # copy G0 so that it is not modified in place.
+ G = nx.copy(G0)
- f_val = cost(G)
+ if reg2 is None:
+ def cost(G):
+ return nx.sum(M * G) + reg1 * f(G)
+ else:
+ def cost(G):
+ return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G))
+ cost_G = cost(G)
if log:
- log['loss'].append(f_val)
+ log['loss'].append(cost_G)
it = 0
if verbose:
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0))
while loop:
it += 1
- old_fval = f_val
-
+ old_cost_G = cost_G
# problem linearization
- Mi = M + reg * df(G)
+ Mi = M + reg1 * df(G)
+
+ if reg2 is not None:
+ Mi = Mi + reg2 * (1 + nx.log(G))
# set M positive
- Mi += nx.min(Mi)
+ Mi = Mi + nx.min(Mi)
# solve linear program
- Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
+ Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs)
+ # line search
deltaG = Gc - G
- # line search
- alpha, fc, f_val = solve_linesearch(
- cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,
- alpha_min=0., alpha_max=1., **kwargs
- )
+ alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs)
G = G + alpha * deltaG
@@ -308,29 +267,197 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
if it >= numItermax:
loop = 0
- abs_delta_fval = abs(f_val - old_fval)
- relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
+ abs_delta_cost_G = abs(cost_G - old_cost_G)
+ relative_delta_cost_G = abs_delta_cost_G / abs(cost_G)
+ if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2:
loop = 0
if log:
- log['loss'].append(f_val)
+ log['loss'].append(cost_G)
if verbose:
if it % 20 == 0:
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G))
if log:
- log.update(logemd)
+ log.update(innerlog_)
return G, log
else:
return G
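Taken together, `generic_conditional_gradient` factors the former `cg`/`gcg` loops into one solver parameterized by a direction-finding oracle (`lp_solver`) and a step-size rule (`line_search`). A minimal sketch of a direct call, not in the commit itself and assuming a POT version that ships this function; the squared-Frobenius regularizer is only a stand-in for `f`/`df`:

    import numpy as np
    import ot
    from ot.optim import generic_conditional_gradient, line_search_armijo

    rng = np.random.RandomState(0)
    a, b = ot.unif(20), ot.unif(30)
    M = ot.dist(rng.randn(20, 2), rng.randn(30, 2))

    def f(G):                      # illustrative squared-Frobenius regularizer
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    def lp_solver(a, b, Mi, **kwargs):
        # exact EMD as the direction-finding oracle; a log dict must be returned
        return ot.emd(a, b, Mi, log=True)

    # reg2=None selects the plain conditional-gradient objective (no entropic term)
    G, log = generic_conditional_gradient(
        a, b, M, f, df, 1e-1, None, lp_solver, line_search_armijo, log=True)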
+def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
+ numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9,
+ verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized OT problem with conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma^T \mathbf{1} &= \mathbf{b}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (ns,)
+ samples weights in the source domain
+ b : array-like, shape (nt,)
+ samples weights in the target domain
+ M : array-like, shape (ns, nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ G0 : array-like, shape (ns,nt), optional
+ initial guess (default is indep joint density)
+ line_search: function,
+ Function to find the optimal step.
+ Default is line_search_armijo.
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ **kwargs : dict
+ Parameters for linesearch
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ .. _references-cg:
+ References
+ ----------
+
+ .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized optimal transport
+ ot.bregman.sinkhorn : Entropic regularized optimal transport
+
+ """
+
+ def lp_solver(a, b, M, **kwargs):
+ return emd(a, b, M, numItermaxEmd, log=True)
+
+ return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
+
+
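The rewritten `cg` keeps its public signature (up to the new `line_search` argument), so the usual quadratic-regularized OT call still works unchanged. A short usage sketch, illustrative rather than part of the commit:

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    xs, xt = rng.randn(50, 2), rng.randn(60, 2)
    a, b = ot.unif(50), ot.unif(60)
    M = ot.dist(xs, xt)
    M /= M.max()

    def f(G):                      # squared-Frobenius regularization
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df)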
+def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
+ numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized and semi-relaxed OT problem with conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-cg>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (ns,)
+ samples weights in the source domain
+ b : array-like, shape (nt,)
+ currently estimated samples weights in the target domain
+ M : array-like, shape (ns, nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ G0 : array-like, shape (ns,nt), optional
+ initial guess (default is indep joint density)
+ line_search: function,
+ Function to find the optimal step.
+ Default is the armijo line-search.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ **kwargs : dict
+ Parameters for linesearch
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ .. _references-semirelaxed-cg:
+ References
+ ----------
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2021.
+
+ """
+
+ nx = get_backend(a, b)
+
+ def lp_solver(a, b, Mi, **kwargs):
+ # get minimum by rows as binary mask
+ Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1)))
+ Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1))
+ # return by default an empty inner_log
+ return Gc, {}
+
+ return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
+
+
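The `lp_solver` closure above builds the semi-relaxed direction analytically: since only the row marginal is constrained, each source point sends all of its mass to the column(s) minimizing the linearized cost on its row, split evenly among ties. A plain NumPy transcription of that step, with a hypothetical helper name and not taken from the commit:

    import numpy as np

    def semirelaxed_direction(a, Mi):
        # binary mask of row-wise minima, then rescale each row to the weight a_i
        mask = (Mi == Mi.min(axis=1, keepdims=True)).astype(Mi.dtype)
        return mask * (a / mask.sum(axis=1))[:, None]

    a = np.array([0.5, 0.3, 0.2])
    Mi = np.array([[0., 1., 2.],
                   [3., 0., 0.],
                   [1., 2., 0.]])
    Gc = semirelaxed_direction(a, Mi)
    # row 1 has a tie, so its mass 0.3 is split into 0.15 / 0.15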
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
- numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
+ numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
r"""
Solve the general regularized OT problem with the generalized conditional gradient
@@ -403,81 +530,18 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
ot.optim.cg : conditional gradient
"""
- a, b, M, G0 = list_to_array(a, b, M, G0)
- nx = get_backend(a, b, M)
-
- loop = 1
-
- if log:
- log = {'loss': []}
-
- if G0 is None:
- G = nx.outer(a, b)
- else:
- G = G0
-
- def cost(G):
- return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G)
-
- f_val = cost(G)
- if log:
- log['loss'].append(f_val)
-
- it = 0
-
- if verbose:
- print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
- while loop:
-
- it += 1
- old_fval = f_val
-
- # problem linearization
- Mi = M + reg2 * df(G)
-
- # solve linear program with Sinkhorn
- # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
- Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax)
-
- deltaG = Gc - G
-
- # line search
- dcost = Mi + reg1 * (1 + nx.log(G)) # ??
- alpha, fc, f_val = line_search_armijo(
- cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1.
- )
-
- G = G + alpha * deltaG
-
- # test convergence
- if it >= numItermax:
- loop = 0
-
- abs_delta_fval = abs(f_val - old_fval)
- relative_delta_fval = abs_delta_fval / abs(f_val)
+ def lp_solver(a, b, Mi, **kwargs):
+ return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs)
- if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
- loop = 0
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)
- if log:
- log['loss'].append(f_val)
+ return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
- if verbose:
- if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
- if log:
- return G, log
- else:
- return G
-
-
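With this refactor `gcg` reduces to a thin wrapper: Sinkhorn provides the direction-finding oracle and an Armijo rule the step size, with `reg1` (entropic) and `reg2` (generic regularizer) mapped onto the generic solver's `reg2`/`reg1`. A short usage sketch, illustrative only, with a stand-in quadratic regularizer:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(30), ot.unif(30)
    M = ot.dist(rng.randn(30, 2), rng.randn(30, 2))
    M /= M.max()

    def f(G):                      # illustrative smooth regularizer
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    # reg1 weights the entropic term handled by the inner Sinkhorn iterations,
    # reg2 weights f(G)
    G = ot.optim.gcg(a, b, M, reg1=1e-1, reg2=1e-1, f=f, df=df)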
-def solve_1d_linesearch_quad(a, b, c):
+def solve_1d_linesearch_quad(a, b):
r"""
For any convex or non-convex 1d quadratic function `f`, solve the following problem:
@@ -487,7 +551,7 @@ def solve_1d_linesearch_quad(a, b, c):
Parameters
----------
- a,b,c : float
+ a, b : float or tensors of shape (1,)
The coefficients of the quadratic function
Returns
@@ -495,15 +559,11 @@ def solve_1d_linesearch_quad(a, b, c):
x : float
The optimal value which leads to the minimal cost
"""
- f0 = c
- df0 = b
- f1 = a + f0 + df0
-
if a > 0: # convex
- minimum = min(1, max(0, np.divide(-b, 2.0 * a)))
+ minimum = min(1., max(0., -b / (2.0 * a)))
return minimum
else: # non convex
- if f0 > f1:
- return 1
+ if a + b < 0:
+ return 1.
else:
- return 0
+ return 0.
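For reference, the two-argument `solve_1d_linesearch_quad` implements the closed-form minimizer of g(t) = a*t^2 + b*t + c over [0, 1]; the constant c never affects the argmin, which is why the third argument could be dropped. A standalone sketch of the same rule, with a hypothetical function name and not taken from the commit:

    def quad_argmin_01(a, b):
        # argmin over [0, 1] of g(t) = a * t**2 + b * t (+ any constant)
        if a > 0:                              # convex: clip the vertex -b / (2a) to [0, 1]
            return min(1., max(0., -b / (2. * a)))
        # concave or linear case: the minimum sits at an endpoint; g(1) - g(0) = a + b
        return 1. if a + b < 0 else 0.

    assert quad_argmin_01(1., -1.) == 0.5      # interior minimum of t**2 - t
    assert quad_argmin_01(-1., 0.) == 1.       # concave, decreasing over [0, 1]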