From fd115a538deb8fa9dcf3169fcfa6b85aebd36b07 Mon Sep 17 00:00:00 2001 From: ievred Date: Mon, 20 Apr 2020 13:55:45 +0200 Subject: sim+sim param fixed --- ot/da.py | 53 +++++++++++++++++++++++++---------------------------- 1 file changed, 25 insertions(+), 28 deletions(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index 8e26e31..300af30 100644 --- a/ot/da.py +++ b/ot/da.py @@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, return A, b -def emd_laplace(a, b, xs, xt, M, reg, eta, alpha, +def emd_laplace(a, b, xs, xt, M, sim, sim_param, reg, eta, alpha, numItermax, stopThr, numInnerItermax, stopInnerThr, log=False, verbose=False, **kwargs): r"""Solve the optimal transport problem (OT) with Laplacian regularization @@ -785,6 +785,11 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha, samples in the target domain M : np.ndarray (ns,nt) loss matrix + sim : string, optional + Type of similarity ('knn' or 'gauss') used to construct the Laplacian. + sim_param : int or float, optional + Parameter (number of the nearest neighbors for sim='knn' + or bandwidth for sim='gauss') used to compute the Laplacian. reg : string Type of Laplacian regularization eta : float @@ -803,11 +808,6 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha, Print information along iterations log : bool, optional record log if True - kwargs : dict - Dictionary with attributes 'sim' ('knn' or 'gauss') and - 'param' (int, float or None) for similarity type and its parameter to be used. - If 'param' is None, it is computed as mean pairwise Euclidean distance over the data set - or set to 3 when sim is 'gauss' or 'knn', respectively. Returns ------- @@ -824,7 +824,7 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 - .. [28] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy, + .. [30] R. Flamary, N. Courty, D. Tuia, A. 
Rakotomamonjy, "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching," in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. @@ -834,28 +834,28 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha, ot.optim.cg : General regularized OT """ - if not isinstance(kwargs['param'], (int, float, type(None))): + if not isinstance(sim_param, (int, float, type(None))): raise ValueError( - 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(kwargs['param']))) + 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__)) - if kwargs['sim'] == 'gauss': - if kwargs['param'] is None: - kwargs['param'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) - sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['param']) - sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['param']) + if sim == 'gauss': + if sim_param is None: + sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) + sS = kernel(xs, xs, method=sim, sigma=sim_param) + sT = kernel(xt, xt, method=sim, sigma=sim_param) - elif kwargs['sim'] == 'knn': - if kwargs['param'] is None: - kwargs['param'] = 3 + elif sim == 'knn': + if sim_param is None: + sim_param = 3 from sklearn.neighbors import kneighbors_graph - sS = kneighbors_graph(X=xs, n_neighbors=int(kwargs['param'])).toarray() + sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray() sS = (sS + sS.T) / 2 - sT = kneighbors_graph(xt, n_neighbors=int(kwargs['param'])).toarray() + sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray() sT = (sT + sT.T) / 2 else: - raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=kwargs['sim'])) + raise ValueError('Unknown similarity type {sim}. 
Currently supported similarity types are "knn" and "gauss".'.format(sim=sim)) lS = laplacian(sS) lT = laplacian(sT) @@ -1729,9 +1729,10 @@ class EMDLaplaceTransport(BaseTransport): can occur with large metric values. similarity : string, optional (default="knn") The similarity to use either knn or gaussian - similarity_param : int or float, optional (default=3) + similarity_param : int or float, optional (default=None) Parameter for the similarity: number of nearest neighbors or bandwidth - if similarity="knn" or "gaussian", respectively. + if similarity="knn" or "gaussian", respectively. If None is provided, + it is set to 3 or the average pairwise squared Euclidean distance, respectively. max_iter : int, optional (default=100) Max number of BCD iterations tol : float, optional (default=1e-5) @@ -1813,14 +1814,10 @@ class EMDLaplaceTransport(BaseTransport): super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt) - kwargs = dict() - kwargs["sim"] = self.similarity - kwargs["param"] = self.sim_param - returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_, - xt=self.xt_, M=self.cost_, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src, + xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src, numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, - stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose, **kwargs) + stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose) # coupling estimation if self.log: -- cgit v1.2.3