summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-04 11:45:14 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-04 11:45:14 +0100
commit151aa9b0c92d151bc9ed0edbe2217506652f19ec (patch)
treec3257b116710249c56ab906da0a4d936e2c3c675 /ot/da.py
parent5fe917cd92541c1d869e342a841756cd53927a8a (diff)
doc linear mapping estimation
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py75
1 files changed, 74 insertions, 1 deletions
diff --git a/ot/da.py b/ot/da.py
index 4e5fda2..76bc6a3 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -124,7 +124,80 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
return transp
def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
- """Joint Ot and mapping estimation (uniform weights and )
+ """Joint OT and linear mapping estimation as proposed in [8]
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L -I\|^2_F
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) squared euclidean cost matrix between samples in Xs and Xt (scaled by ns)
+ - :math:`L` is a dxd linear operator that approximates the barycentric mapping
+ - :math:`I` is the identity matrix (neutral linear mapping)
+ - a and b are uniform source and target weights
+
+ The problem consist in solving jointly an optimal transport matrix
+ :math:`\gamma` and a linear mapping that fits the barycentric mapping
+ :math:`n_s\gamma X_t`
+
+ The algorithm used for solving the problem is the block coordinate
+ descent that alternates between updates of G (using conditionnal gradient)
+ abd the update of L using a classical least square solver.
+
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+ mu: float,optional
+ Weight for the linear OT loss (>0)
+ eta: float, optional
+ Regularization term for the linear mapping L (>0)
+ bias: bool,optional
+ Estimate linear mapping with constant bias
+ numItermax: int, optional
+ Max number of BCD iterations
+ stopThr: float, optional
+ Stop threshold on relative loss decrease (>0)
+ numInnerItermax: int, optional
+ Max number of iterations (inner CG solver)
+ stopInnerThr: float, optional
+ Stop threshold on error (inner CG solver) (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma: (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ L: (d x d) ndarray
+ Linear mapping matrix (d+1 x d if bias)
+ log: dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
"""
ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1]