From 126903381374a1d2f934190b208d134a0495dc65 Mon Sep 17 00:00:00 2001 From: ievred Date: Fri, 17 Apr 2020 16:41:14 +0200 Subject: added regulrization from [6]+fix other issues --- ot/da.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index be959d6..9e00dce 100644 --- a/ot/da.py +++ b/ot/da.py @@ -16,7 +16,7 @@ import scipy.linalg as linalg from .bregman import sinkhorn, jcpot_barycenter from .lp import emd -from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian +from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots from .utils import check_params, BaseEstimator from .unbalanced import sinkhorn_unbalanced from .optim import cg @@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, return A, b -def emd_laplace(a, b, xs, xt, M, sim, eta, alpha, +def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha, numItermax, stopThr, numInnerItermax, stopInnerThr, log=False, verbose=False, **kwargs): r"""Solve the optimal transport problem (OT) with Laplacian regularization @@ -785,6 +785,8 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha, samples in the target domain M : np.ndarray (ns,nt) loss matrix + reg : string + Type of Laplacian regularization eta : float Regularization term for Laplacian regularization alpha : float @@ -844,6 +846,8 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha, sS = (sS + sS.T) / 2 sT = kneighbors_graph(xt, kwargs['nn']).toarray() sT = (sT + sT.T) / 2 + else: + raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim)) lS = laplacian(sS) lT = laplacian(sT) @@ -852,9 +856,18 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha, return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \ + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs))))) + ls2 = lS + lS.T + lt2 = lT + lT.T + xt2 = np.dot(xt, xt.T) + + if reg == 'disp': + Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T) + Ct = -eta * (1 - alpha) / xt.shape[0] * dots(xs, xt.T, lt2) + M = M + Cs + Ct + def df(G): - return alpha * np.dot(lS + lS.T, np.dot(G, np.dot(xt, xt.T)))\ - + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lT + lT.T))) + return alpha * np.dot(ls2, np.dot(G, xt2))\ + + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2))) return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax, stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log) @@ -1694,6 +1707,9 @@ class EMDLaplaceTransport(BaseTransport): Parameters ---------- + reg_type : string optional (default='pos') + Type of the regularization term: 'pos' and 'disp' for + regularization term defined in [2] and [6], respectively. reg_lap : float, optional (default=1) Laplacian regularization parameter reg_src : float, optional (default=0.5) @@ -1737,11 +1753,12 @@ class EMDLaplaceTransport(BaseTransport): in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. """ - def __init__(self, reg_lap=1., reg_src=1., alpha=0.5, + def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5, metric="sqeuclidean", norm=None, similarity="knn", max_iter=100, tol=1e-9, max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans'): + self.reg = reg_type self.reg_lap = reg_lap self.reg_src = reg_src self.alpha = alpha @@ -1785,7 +1802,7 @@ class EMDLaplaceTransport(BaseTransport): super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt) returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_, - xt=self.xt_, M=self.cost_, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src, + xt=self.xt_, M=self.cost_, reg=self.reg, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src, numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose) -- cgit v1.2.3