path: root/ot/da.py
author     Slasnista <stan.chambon@gmail.com>    2017-07-28 08:00:35 +0200
committer  Slasnista <stan.chambon@gmail.com>    2017-07-28 08:00:35 +0200
commit     553a45678c829896cbb076b8a89934525431c62c (patch)
tree       117e1404a6801a81c0b1e66cc81d22471c62fcbb /ot/da.py
parent     7638d019b43e52d17600cac653939e7cd807478c (diff)
remove linewidth error message
Diffstat (limited to 'ot/da.py')
-rw-r--r--  ot/da.py  150
1 file changed, 108 insertions(+), 42 deletions(-)
diff --git a/ot/da.py b/ot/da.py
index 4f9bce5..1dd4011 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -17,14 +17,18 @@ from .optim import cg
from .optim import gcg
-def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
+def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
+ numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
+ log=False):
"""
- Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
+ Solve the entropic regularization optimal transport problem with nonconvex
+ group lasso regularization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
+ + \eta \Omega_g(\gamma)
s.t. \gamma 1 = a
@@ -34,11 +38,16 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
where :
- M is the (ns,nt) metric cost matrix
- - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain.
+ - :math:`\Omega_e` is the entropic regularization term
+ :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega_g` is the group lasso regularization term
+ :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
+ where :math:`\mathcal{I}_c` is the set of indices of samples from
+ class c in the source domain.
- a and b are source and target weights (sum to 1)
- The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
+ The algorithm used for solving the problem is the generalised conditional
+ gradient as proposed in [5]_ [7]_
Parameters
@@ -78,8 +87,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
References
----------
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
- .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence,
+ vol. PP, no. 99, pp. 1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
See Also
--------
@@ -114,14 +128,18 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
return transp
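For readers of these hunks, a minimal usage sketch of sinkhorn_lpl1_mm on toy data (the data and parameter values are illustrative, and it assumes this revision of POT is importable as ot; the call signature is the one shown in the hunk above):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    ns, nt = 20, 30
    xs = rng.randn(ns, 2)                  # source samples
    labels_a = rng.randint(0, 2, ns)       # source class labels
    xt = rng.randn(nt, 2) + 1              # shifted target samples
    a, b = ot.unif(ns), ot.unif(nt)        # uniform weights, sum to 1
    M = ot.dist(xs, xt)                    # (ns, nt) cost matrix
    G = ot.da.sinkhorn_lpl1_mm(a, labels_a, b, M, reg=1e-1, eta=1e-1)
    # G is the (ns, nt) coupling; the lp-l1 group penalty pushes each
    # target sample to receive mass from a single source class.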
-def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
+def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
+ numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
+ log=False):
"""
- Solve the entropic regularization optimal transport problem with group lasso regularization
+ Solve the entropic regularization optimal transport problem with group
+ lasso regularization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
+ + \eta \Omega_g(\gamma)
s.t. \gamma 1 = a
@@ -131,11 +149,16 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
where :
- M is the (ns,nt) metric cost matrix
- - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain.
+ - :math:`\Omega_e` is the entropic regularization term
+ :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega_g` is the group lasso regularization term
+ :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2`
+ where :math:`\mathcal{I}_c` is the set of indices of samples from
+ class c in the source domain.
- a and b are source and target weights (sum to 1)
- The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
+ The algorithm used for solving the problem is the generalised conditional
+ gradient as proposed in [5]_ [7]_
Parameters
@@ -175,8 +198,12 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
References
----------
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
- .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
+ on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence and
+ applications. arXiv preprint arXiv:1510.06567.
See Also
--------
@@ -203,16 +230,22 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
W[labels_a == lab, i] = temp / n
return W
- return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax, numInnerItermax=numInnerItermax, stopThr=stopInnerThr, verbose=verbose, log=log)
+ return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
+ numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
+ verbose=verbose, log=log)
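The l1l2 variant re-wrapped above shares the signature of sinkhorn_lpl1_mm but penalizes the squared l2 norm of each class block, and (being convex) is solved with the generalized conditional gradient gcg as the return statement shows. A sketch under the same toy setup (illustrative values only):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    ns, nt = 20, 30
    xs, xt = rng.randn(ns, 2), rng.randn(nt, 2) + 1
    labels_a = rng.randint(0, 2, ns)
    a, b = ot.unif(ns), ot.unif(nt)
    M = ot.dist(xs, xt)
    # Same call pattern as sinkhorn_lpl1_mm; only the group penalty differs.
    G = ot.da.sinkhorn_l1l2_gl(a, labels_a, b, M, reg=1e-1, eta=1e-1)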
-def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs):
+def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
+ verbose2=False, numItermax=100, numInnerItermax=10,
+ stopInnerThr=1e-6, stopThr=1e-5, log=False,
+ **kwargs):
"""Joint OT and linear mapping estimation as proposed in [8]
The function solves the following optimization problem:
.. math::
- \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L -I\|^2_F
+ \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F +
+ \mu<\gamma,M>_F + \eta \|L -I\|^2_F
s.t. \gamma 1 = a
@@ -221,8 +254,10 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
\gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in Xs and Xt (scaled by ns)
- - :math:`L` is a dxd linear operator that approximates the barycentric mapping
+ - M is the (ns,nt) squared Euclidean cost matrix between samples in
+ Xs and Xt (scaled by ns)
+ - :math:`L` is a dxd linear operator that approximates the barycentric
+ mapping
- :math:`I` is the identity matrix (neutral linear mapping)
- a and b are uniform source and target weights
@@ -277,7 +312,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
References
----------
- .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
See Also
--------
@@ -384,13 +421,18 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
return G, L
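A sketch of the linear mapping estimator whose signature is re-wrapped above (toy data and illustrative parameters; with bias=True the returned L is a (d+1) x d matrix applied to samples augmented with a ones column, matching OTDA_mapping_linear.predict later in this file):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(30, 2)
    xt = rng.randn(30, 2) + 2
    G, L = ot.da.joint_OT_mapping_linear(xs, xt, mu=1, eta=1e-3,
                                         bias=True)
    # Apply the estimated affine map to the source samples.
    xs1 = np.hstack((xs, np.ones((xs.shape[0], 1))))
    xs_mapped = xs1.dot(L)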
-def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigma=1, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs):
+def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
+ sigma=1, bias=False, verbose=False, verbose2=False,
+ numItermax=100, numInnerItermax=10,
+ stopInnerThr=1e-6, stopThr=1e-5, log=False,
+ **kwargs):
"""Joint OT and nonlinear mapping estimation with kernels as proposed in [8]
The function solves the following optimization problem:
.. math::
- \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
+ \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -
+ n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
s.t. \gamma 1 = a
@@ -399,8 +441,10 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm
\gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in Xs and Xt (scaled by ns)
- - :math:`L` is a ns x d linear operator on a kernel matrix that approximates the barycentric mapping
+ - M is the (ns,nt) squared Euclidean cost matrix between samples in
+ Xs and Xt (scaled by ns)
+ - :math:`L` is a ns x d linear operator on a kernel matrix that
+ approximates the barycentric mapping
- a and b are uniform source and target weights
The problem consists in jointly solving an optimal transport matrix
@@ -458,7 +502,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm
References
----------
- .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
See Also
--------
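And the kernel counterpart (illustrative parameters; here L acts on a Gaussian kernel matrix between new points and xs, so the estimated map can be nonlinear):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(30, 2)
    xt = rng.randn(30, 2) + 2
    G, L = ot.da.joint_OT_mapping_kernel(xs, xt, mu=1, eta=1e-3,
                                         kerneltype='gaussian', sigma=1,
                                         bias=False)
    # L has shape (ns, d); mapping a new point x requires the kernel row
    # K(x, xs), cf. the predict method at the end of this diff.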
@@ -593,7 +639,9 @@ class OTDA(object):
References
----------
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions on
+ Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
"""
@@ -606,7 +654,8 @@ class OTDA(object):
self.computed = False
def fit(self, xs, xt, ws=None, wt=None, norm=None):
- """ Fit domain adaptation between samples is xs and xt (with optional weights)"""
+ """Fit domain adaptation between samples is xs and xt
+ (with optional weights)"""
self.xs = xs
self.xt = xt
@@ -669,7 +718,9 @@ class OTDA(object):
References
----------
- .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
if direction > 0: # >0 then source to target
@@ -708,10 +759,12 @@ class OTDA(object):
class OTDA_sinkhorn(OTDA):
- """Class for domain adaptation with optimal transport with entropic regularization"""
+ """Class for domain adaptation with optimal transport with entropic
+ regularization"""
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
- """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
+ """Fit regularized domain adaptation between samples is xs and xt
+ (with optional weights)"""
self.xs = xs
self.xt = xt
@@ -731,10 +784,14 @@ class OTDA_sinkhorn(OTDA):
class OTDA_lpl1(OTDA):
- """Class for domain adaptation with optimal transport with entropic and group regularization"""
+ """Class for domain adaptation with optimal transport with entropic and
+ group regularization"""
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
- """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters"""
+ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
+ **kwargs):
+ """Fit regularized domain adaptation between samples is xs and xt
+ (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
+ parameters"""
self.xs = xs
self.xt = xt
@@ -754,10 +811,14 @@ class OTDA_lpl1(OTDA):
class OTDA_l1l2(OTDA):
- """Class for domain adaptation with optimal transport with entropic and group lasso regularization"""
+ """Class for domain adaptation with optimal transport with entropic
+ and group lasso regularization"""
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
- """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters"""
+ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
+ **kwargs):
+ """Fit regularized domain adaptation between samples is xs and xt
+ (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
+ parameters"""
self.xs = xs
self.xt = xt
@@ -777,7 +838,9 @@ class OTDA_l1l2(OTDA):
class OTDA_mapping_linear(OTDA):
- """Class for optimal transport with joint linear mapping estimation as in [8]"""
+ """Class for optimal transport with joint linear mapping estimation as in
+ [8]
+ """
def __init__(self):
""" Class initialization"""
@@ -820,9 +883,11 @@ class OTDA_mapping_linear(OTDA):
class OTDA_mapping_kernel(OTDA_mapping_linear):
- """Class for optimal transport with joint nonlinear mapping estimation as in [8]"""
+ """Class for optimal transport with joint nonlinear mapping
+ estimation as in [8]"""
- def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian', sigma=1, **kwargs):
+ def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian',
+ sigma=1, **kwargs):
""" Fit domain adaptation between samples is xs and xt """
self.xs = xs
self.xt = xt
@@ -843,7 +908,8 @@ class OTDA_mapping_kernel(OTDA_mapping_linear):
if self.computed:
K = kernel(
- x, self.xs, method=self.kernel, sigma=self.sigma, **self.kwargs)
+ x, self.xs, method=self.kernel, sigma=self.sigma,
+ **self.kwargs)
if self.bias:
K = np.hstack((K, np.ones((x.shape[0], 1))))
return K.dot(self.L)
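Finally, a sketch of the mapping classes whose predict method this last hunk re-wraps; unlike the barycentric interp, predict extends the estimated map to samples unseen at fit time (toy data, illustrative parameters):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(30, 2)
    xt = rng.randn(30, 2) + 2
    da = ot.da.OTDA_mapping_kernel()
    da.fit(xs, xt, mu=1, eta=1e-3, kerneltype='gaussian', sigma=1)
    x_new = rng.randn(5, 2)           # out-of-sample points
    x_new_mapped = da.predict(x_new)  # K(x_new, xs).dot(L), as above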