summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-20 13:55:45 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-20 13:55:45 +0200
commitfd115a538deb8fa9dcf3169fcfa6b85aebd36b07 (patch)
tree540b2a67b74672e7dbc74b1f1f7e347aa5deafd2 /ot/da.py
parent1a36193c6616b8ad89fe0e0f2a5a7ab137e9d820 (diff)
sim+sim param fixed
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py53
1 files changed, 25 insertions, 28 deletions
diff --git a/ot/da.py b/ot/da.py
index 8e26e31..300af30 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b
-def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
+def emd_laplace(a, b, xs, xt, M, sim, sim_param, reg, eta, alpha,
numItermax, stopThr, numInnerItermax,
stopInnerThr, log=False, verbose=False, **kwargs):
r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -785,6 +785,11 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
samples in the target domain
M : np.ndarray (ns,nt)
loss matrix
+ sim : string
+ Type of similarity ('knn' or 'gauss') used to construct the Laplacian.
+ sim_param : int, float or None
+ Parameter (number of the nearest neighbors for sim='knn'
+ or bandwidth for sim='gauss') used to compute the Laplacian.
reg : string
Type of Laplacian regularization
eta : float
@@ -803,11 +808,6 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
Print information along iterations
log : bool, optional
record log if True
- kwargs : dict
- Dictionary with attributes 'sim' ('knn' or 'gauss') and
- 'param' (int, float or None) for similarity type and its parameter to be used.
- If 'param' is None, it is computed as mean pairwise Euclidean distance over the data set
- or set to 3 when sim is 'gauss' or 'knn', respectively.
Returns
-------
@@ -824,7 +824,7 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
"Optimal Transport for Domain Adaptation," in IEEE
Transactions on Pattern Analysis and Machine Intelligence ,
vol.PP, no.99, pp.1-1
- .. [28] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
+ .. [30] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
"Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
@@ -834,28 +834,28 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
ot.optim.cg : General regularized OT
"""
- if not isinstance(kwargs['param'], (int, float, type(None))):
+ if not isinstance(sim_param, (int, float, type(None))):
raise ValueError(
- 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(kwargs['param'])))
+ 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__))
- if kwargs['sim'] == 'gauss':
- if kwargs['param'] is None:
- kwargs['param'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
- sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['param'])
- sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['param'])
+ if sim == 'gauss':
+ if sim_param is None:
+ sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
+ sS = kernel(xs, xs, method=sim, sigma=sim_param)
+ sT = kernel(xt, xt, method=sim, sigma=sim_param)
- elif kwargs['sim'] == 'knn':
- if kwargs['param'] is None:
- kwargs['param'] = 3
+ elif sim == 'knn':
+ if sim_param is None:
+ sim_param = 3
from sklearn.neighbors import kneighbors_graph
- sS = kneighbors_graph(X=xs, n_neighbors=int(kwargs['param'])).toarray()
+ sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray()
sS = (sS + sS.T) / 2
- sT = kneighbors_graph(xt, n_neighbors=int(kwargs['param'])).toarray()
+ sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray()
sT = (sT + sT.T) / 2
else:
- raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=kwargs['sim']))
+ raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim))
lS = laplacian(sS)
lT = laplacian(sT)
@@ -1729,9 +1729,10 @@ class EMDLaplaceTransport(BaseTransport):
can occur with large metric values.
similarity : string, optional (default="knn")
The similarity to use, either "knn" or "gauss"
- similarity_param : int or float, optional (default=3)
+ similarity_param : int or float, optional (default=None)
Parameter for the similarity: number of nearest neighbors or bandwidth
- if similarity="knn" or "gaussian", respectively.
+ if similarity="knn" or "gauss", respectively. If None is provided, it is set to 3 or to
+ 1 / (2 * m ** 2), with m the mean pairwise squared Euclidean distance of the source samples, respectively.
max_iter : int, optional (default=100)
Max number of BCD iterations
tol : float, optional (default=1e-5)
@@ -1813,14 +1814,10 @@ class EMDLaplaceTransport(BaseTransport):
super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt)
- kwargs = dict()
- kwargs["sim"] = self.similarity
- kwargs["param"] = self.sim_param
-
returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_,
- xt=self.xt_, M=self.cost_, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src,
+ xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src,
numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
- stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose, **kwargs)
+ stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose)
# coupling estimation
if self.log: