summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-17 16:41:14 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-17 16:41:14 +0200
commit126903381374a1d2f934190b208d134a0495dc65 (patch)
tree98794ce5274cd9153207b389b526305b48494845 /ot/da.py
parent14fbb88333971f575510747fd6e9217ec50d9780 (diff)
added regularization from [6] + fix other issues
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/ot/da.py b/ot/da.py
index be959d6..9e00dce 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -16,7 +16,7 @@ import scipy.linalg as linalg
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
-from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian
+from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
from .utils import check_params, BaseEstimator
from .unbalanced import sinkhorn_unbalanced
from .optim import cg
@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b
-def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
+def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
numItermax, stopThr, numInnerItermax,
stopInnerThr, log=False, verbose=False, **kwargs):
r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -785,6 +785,8 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
samples in the target domain
M : np.ndarray (ns,nt)
loss matrix
+ reg : string
+ Type of Laplacian regularization: 'pos' or 'disp'
eta : float
Regularization term for Laplacian regularization
alpha : float
@@ -844,6 +846,8 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
sS = (sS + sS.T) / 2
sT = kneighbors_graph(xt, kwargs['nn']).toarray()
sT = (sT + sT.T) / 2
+ else:
+ raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim))
lS = laplacian(sS)
lT = laplacian(sT)
@@ -852,9 +856,18 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \
+ (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs)))))
+ ls2 = lS + lS.T
+ lt2 = lT + lT.T
+ xt2 = np.dot(xt, xt.T)
+
+ if reg == 'disp':
+ Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T)
+ Ct = -eta * (1 - alpha) / xt.shape[0] * dots(xs, xt.T, lt2)
+ M = M + Cs + Ct
+
def df(G):
- return alpha * np.dot(lS + lS.T, np.dot(G, np.dot(xt, xt.T)))\
- + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lT + lT.T)))
+ return alpha * np.dot(ls2, np.dot(G, xt2))\
+ + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2)))
return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax,
stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log)
@@ -1694,6 +1707,9 @@ class EMDLaplaceTransport(BaseTransport):
Parameters
----------
+ reg_type : string, optional (default='pos')
+ Type of the regularization term: 'pos' and 'disp' for
+ regularization term defined in [2] and [6], respectively.
reg_lap : float, optional (default=1)
Laplacian regularization parameter
reg_src : float, optional (default=0.5)
@@ -1737,11 +1753,12 @@ class EMDLaplaceTransport(BaseTransport):
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
"""
- def __init__(self, reg_lap=1., reg_src=1., alpha=0.5,
+ def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5,
metric="sqeuclidean", norm=None, similarity="knn", max_iter=100, tol=1e-9,
max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans'):
+ self.reg = reg_type
self.reg_lap = reg_lap
self.reg_src = reg_src
self.alpha = alpha
@@ -1785,7 +1802,7 @@ class EMDLaplaceTransport(BaseTransport):
super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt)
returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_,
- xt=self.xt_, M=self.cost_, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src,
+ xt=self.xt_, M=self.cost_, reg=self.reg, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src,
numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose)