summary | refs | log | tree | commit | diff
path: root/ot/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/da.py')
-rw-r--r-- ot/da.py 136
1 files changed, 20 insertions, 116 deletions
diff --git a/ot/da.py b/ot/da.py
index 0b9737e..5067a69 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -17,8 +17,9 @@ from .backend import get_backend
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import list_to_array, check_params, BaseEstimator
+from .utils import list_to_array, check_params, BaseEstimator, deprecated
from .unbalanced import sinkhorn_unbalanced
+from .gaussian import empirical_bures_wasserstein_mapping
from .optim import cg
from .optim import gcg
@@ -126,8 +127,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
W = nx.zeros(M.shape, type_as=M)
for cpt in range(numItermax):
Mreg = M + eta * W
- transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
- stopThr=stopInnerThr)
+ if log:
+ transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr, log=True)
+ else:
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really
# separated
W = nx.ones(M.shape, type_as=M)
@@ -136,7 +141,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
- return transp
+ if log:
+ return transp, log
+ else:
+ return transp
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -672,112 +680,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
-def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
- wt=None, bias=True, log=False):
- r"""Return OT linear operator between samples.
-
- The function estimates the optimal linear operator that aligns the two
- empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
- and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
- :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in
- :ref:`[15] <references-OT-mapping-linear>`.
-
- The linear operator from source to target :math:`M`
-
- .. math::
- M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
-
- where :
-
- .. math::
- \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
- \Sigma_s^{-1/2}
-
- \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
-
- Parameters
- ----------
- xs : array-like (ns,d)
- samples in the source domain
- xt : array-like (nt,d)
- samples in the target domain
- reg : float,optional
- regularization added to the diagonals of covariances (>0)
- ws : array-like (ns,1), optional
- weights for the source samples
- wt : array-like (nt,1), optional
- weights for the target samples
- bias: boolean, optional
- estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
- log : bool, optional
- record log if True
-
-
- Returns
- -------
- A : (d, d) array-like
- Linear operator
- b : (1, d) array-like
- bias
- log : dict
- log dictionary return only if log==True in parameters
-
-
- .. _references-OT-mapping-linear:
- References
- ----------
- .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
- distributions", Journal of Optimization Theory and Applications
- Vol 43, 1984
-
- .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
-
- """
- xs, xt = list_to_array(xs, xt)
- nx = get_backend(xs, xt)
-
- d = xs.shape[1]
-
- if bias:
- mxs = nx.mean(xs, axis=0)[None, :]
- mxt = nx.mean(xt, axis=0)[None, :]
-
- xs = xs - mxs
- xt = xt - mxt
- else:
- mxs = nx.zeros((1, d), type_as=xs)
- mxt = nx.zeros((1, d), type_as=xs)
-
- if ws is None:
- ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
-
- if wt is None:
- wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
-
- Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
- Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
-
- Cs12 = nx.sqrtm(Cs)
- Cs_12 = nx.inv(Cs12)
-
- M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
-
- A = dots(Cs_12, M0, Cs_12)
-
- b = mxt - nx.dot(mxs, A)
-
- if log:
- log = {}
- log['Cs'] = Cs
- log['Ct'] = Ct
- log['Cs12'] = Cs12
- log['Cs_12'] = Cs_12
- return A, b, log
- else:
- return A, b
+OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping)
def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,
@@ -1371,10 +1274,10 @@ class LinearTransport(BaseTransport):
self.mu_t = self.distribution_estimation(Xt)
# coupling estimation
- returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
- ws=nx.reshape(self.mu_s, (-1, 1)),
- wt=nx.reshape(self.mu_t, (-1, 1)),
- bias=self.bias, log=self.log)
+ returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg,
+ ws=nx.reshape(self.mu_s, (-1, 1)),
+ wt=nx.reshape(self.mu_t, (-1, 1)),
+ bias=self.bias, log=self.log)
# deal with the value of log
if self.log:
@@ -1514,12 +1417,13 @@ class SinkhornTransport(BaseTransport):
Sciences, 7(3), 1853-1882.
"""
- def __init__(self, reg_e=1., max_iter=1000,
+ def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
tol=10e-9, verbose=False, log=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
self.reg_e = reg_e
+ self.method = method
self.max_iter = max_iter
self.tol = tol
self.verbose = verbose
@@ -1560,7 +1464,7 @@ class SinkhornTransport(BaseTransport):
# coupling estimation
returned_ = sinkhorn(
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
- numItermax=self.max_iter, stopThr=self.tol,
+ method=self.method, numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
# deal with the value of log