From 07463285317ed5c989040edcefb5c0e8cd3ac034 Mon Sep 17 00:00:00 2001 From: ievred Date: Mon, 20 Apr 2020 10:39:13 +0200 Subject: added kwargs to sim + doc --- ot/da.py | 48 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 16 deletions(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index 9e00dce..8e26e31 100644 --- a/ot/da.py +++ b/ot/da.py @@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, return A, b -def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha, +def emd_laplace(a, b, xs, xt, M, reg, eta, alpha, numItermax, stopThr, numInnerItermax, stopInnerThr, log=False, verbose=False, **kwargs): r"""Solve the optimal transport problem (OT) with Laplacian regularization @@ -803,7 +803,11 @@ def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha, Print information along iterations log : bool, optional record log if True - + kwargs : dict + Dictionary with attributes 'sim' ('knn' or 'gauss') and + 'param' (int, float or None) for similarity type and its parameter to be used. + If 'param' is None, it is computed as mean pairwise Euclidean distance over the data set + or set to 3 when sim is 'gauss' or 'knn', respectively. Returns ------- @@ -830,24 +834,28 @@ def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha, ot.optim.cg : General regularized OT """ - if sim == 'gauss': - if 'rbfparam' not in kwargs: - kwargs['rbfparam'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) - sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['rbfparam']) - sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['rbfparam']) + if not isinstance(kwargs['param'], (int, float, type(None))): + raise ValueError( + 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(kwargs['param']))) + + if kwargs['sim'] == 'gauss': + if kwargs['param'] is None: + kwargs['param'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) + sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['param']) + sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['param']) - elif sim == 'knn': - if 'nn' not in kwargs: - kwargs['nn'] = 5 + elif kwargs['sim'] == 'knn': + if kwargs['param'] is None: + kwargs['param'] = 3 from sklearn.neighbors import kneighbors_graph - sS = kneighbors_graph(xs, kwargs['nn']).toarray() + sS = kneighbors_graph(X=xs, n_neighbors=int(kwargs['param'])).toarray() sS = (sS + sS.T) / 2 - sT = kneighbors_graph(xt, kwargs['nn']).toarray() + sT = kneighbors_graph(xt, n_neighbors=int(kwargs['param'])).toarray() sT = (sT + sT.T) / 2 else: - raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim)) + raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=kwargs['sim'])) lS = laplacian(sS) lT = laplacian(sT) @@ -1721,6 +1729,9 @@ class EMDLaplaceTransport(BaseTransport): can occur with large metric values. similarity : string, optional (default="knn") The similarity to use either knn or gaussian + similarity_param : int or float, optional (default=3) + Parameter for the similarity: number of nearest neighbors or bandwidth + if similarity="knn" or "gaussian", respectively. max_iter : int, optional (default=100) Max number of BCD iterations tol : float, optional (default=1e-5) @@ -1754,7 +1765,7 @@ class EMDLaplaceTransport(BaseTransport): """ def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5, - metric="sqeuclidean", norm=None, similarity="knn", max_iter=100, tol=1e-9, + metric="sqeuclidean", norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9, max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans'): @@ -1765,6 +1776,7 @@ class EMDLaplaceTransport(BaseTransport): self.metric = metric self.norm = norm self.similarity = similarity + self.sim_param = similarity_param self.max_iter = max_iter self.tol = tol self.max_inner_iter = max_inner_iter @@ -1801,10 +1813,14 @@ class EMDLaplaceTransport(BaseTransport): super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt) + kwargs = dict() + kwargs["sim"] = self.similarity + kwargs["param"] = self.sim_param + returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_, - xt=self.xt_, M=self.cost_, reg=self.reg, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src, + xt=self.xt_, M=self.cost_, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src, numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, - stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose) + stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose, **kwargs) # coupling estimation if self.log: -- cgit v1.2.3