summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py123
1 files changed, 102 insertions, 21 deletions
diff --git a/ot/da.py b/ot/da.py
index 76bc6a3..72ca3ac 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -14,7 +14,6 @@ def indices(a, func):
return [i for (i, val) in enumerate(a) if func(val)]
-
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
"""
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -49,15 +48,15 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
samples in the target domain
M : np.ndarray (ns,nt)
loss matrix
- reg: float
+ reg : float
Regularization term for entropic regularization >0
- eta: float, optional
+ eta : float, optional
Regularization term for group lasso regularization >0
- numItermax: int, optional
+ numItermax : int, optional
Max number of iterations
- numInnerItermax: int, optional
+ numInnerItermax : int, optional
Max number of iterations (inner sinkhorn solver)
- stopInnerThr: float, optional
+ stopInnerThr : float, optional
Stop threshold on error (inner sinkhorn solver) (>0)
verbose : bool, optional
Print information along iterations
@@ -67,9 +66,9 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
Returns
-------
- gamma: (ns x nt) ndarray
+ gamma : (ns x nt) ndarray
Optimal transportation matrix for the given parameters
- log: dict
+ log : dict
log dictionary return only if log==True in parameters
@@ -145,7 +144,10 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
The problem consist in solving jointly an optimal transport matrix
:math:`\gamma` and a linear mapping that fits the barycentric mapping
- :math:`n_s\gamma X_t`
+ :math:`n_s\gamma X_t`.
+
+ One can also estimate a mapping with constant bias (see supplementary
+ material of [8]) using the bias optional argument.
The algorithm used for solving the problem is the block coordinate
descent that alternates between updates of G (using conditionnal gradient)
@@ -158,19 +160,19 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
samples in the source domain
xt : np.ndarray (nt,d)
samples in the target domain
- mu: float,optional
+ mu : float,optional
Weight for the linear OT loss (>0)
- eta: float, optional
+ eta : float, optional
Regularization term for the linear mapping L (>0)
- bias: bool,optional
+ bias : bool,optional
Estimate linear mapping with constant bias
- numItermax: int, optional
+ numItermax : int, optional
Max number of BCD iterations
- stopThr: float, optional
+ stopThr : float, optional
Stop threshold on relative loss decrease (>0)
- numInnerItermax: int, optional
+ numInnerItermax : int, optional
Max number of iterations (inner CG solver)
- stopInnerThr: float, optional
+ stopInnerThr : float, optional
Stop threshold on error (inner CG solver) (>0)
verbose : bool, optional
Print information along iterations
@@ -180,11 +182,11 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
Returns
-------
- gamma: (ns x nt) ndarray
+ gamma : (ns x nt) ndarray
Optimal transportation matrix for the given parameters
- L: (d x d) ndarray
+ L : (d x d) ndarray
Linear mapping matrix (d+1 x d if bias)
- log: dict
+ log : dict
log dictionary return only if log==True in parameters
@@ -291,10 +293,89 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
- """Joint Ot and mapping estimation (uniform weights and )
+ """Joint OT and nonlinear mapping estimation with kernels as proposed in [8]
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) squared euclidean cost matrix between samples in Xs and Xt (scaled by ns)
+ - :math:`L` is a ns x d linear operator on a kernel matrix that approximates the barycentric mapping
+ - a and b are uniform source and target weights
+
+ The problem consist in solving jointly an optimal transport matrix
+ :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping
+ :math:`n_s\gamma X_t`.
+
+ One can also estimate a mapping with constant bias (see supplementary
+ material of [8]) using the bias optional argument.
+
+ The algorithm used for solving the problem is the block coordinate
+ descent that alternates between updates of G (using conditionnal gradient)
+ abd the update of L using a classical kernel least square solver.
+
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+ mu : float,optional
+ Weight for the linear OT loss (>0)
+ eta : float, optional
+ Regularization term for the linear mapping L (>0)
+ bias : bool,optional
+ Estimate linear mapping with constant bias
+ kerneltype : str,optional
+ kernel used by calling function ot.utils.kernel (gaussian by default)
+ sigma : float, optional
+ Gaussian kernel bandwidth.
+ numItermax : int, optional
+ Max number of BCD iterations
+ stopThr : float, optional
+ Stop threshold on relative loss decrease (>0)
+ numInnerItermax : int, optional
+ Max number of iterations (inner CG solver)
+ stopInnerThr : float, optional
+ Stop threshold on error (inner CG solver) (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ L : (ns x d) ndarray
+ Nonlinear mapping matrix (ns+1 x d if bias)
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
"""
- ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1]
+ ns,nt=xs.shape[0],xt.shape[0]
K=kernel(xs,xs,method=kerneltype,sigma=sigma)
if bias: