From 0316d552fa6005aaf0f6231eb9ca20441d5a2532 Mon Sep 17 00:00:00 2001 From: aje Date: Wed, 30 Aug 2017 10:53:31 +0200 Subject: Move normalize function in utils.py --- ot/da.py | 52 ++++++---------------------------------------------- 1 file changed, 6 insertions(+), 46 deletions(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index b4a69b1..61a3ba0 100644 --- a/ot/da.py +++ b/ot/da.py @@ -13,7 +13,7 @@ import numpy as np from .bregman import sinkhorn from .lp import emd -from .utils import unif, dist, kernel +from .utils import unif, dist, kernel, cost_normalization from .utils import check_params, deprecated, BaseEstimator from .optim import cg from .optim import gcg @@ -673,7 +673,7 @@ class OTDA(object): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.normalizeM(norm) + self.M = cost_normalization(self.M, norm) self.G = emd(ws, wt, self.M, max_iter) self.computed = True @@ -741,26 +741,6 @@ class OTDA(object): # aply the delta to the interpolation return xf[idx, :] + x - x0[idx, :] - def normalizeM(self, norm): - """ Apply normalization to the loss matrix - - - Parameters - ---------- - norm : str - type of normalization from 'median','max','log','loglog' - - """ - - if norm == "median": - self.M /= float(np.median(self.M)) - elif norm == "max": - self.M /= float(np.max(self.M)) - elif norm == "log": - self.M = np.log(1 + self.M) - elif norm == "loglog": - self.M = np.log(1 + np.log(1 + self.M)) - @deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be" " removed in 0.5 \nUse class SinkhornTransport instead.") @@ -787,7 +767,7 @@ class OTDA_sinkhorn(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.normalizeM(norm) + self.M = cost_normalization(self.M, norm) self.G = sinkhorn(ws, wt, self.M, reg, **kwargs) self.computed = True @@ -816,7 +796,7 @@ class OTDA_lpl1(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.normalizeM(norm) + self.M = cost_normalization(self.M, norm) self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs) self.computed = True @@ -845,7 +825,7 @@ class OTDA_l1l2(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.normalizeM(norm) + self.M = cost_normalization(self.M, norm) self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs) self.computed = True @@ -1001,7 +981,7 @@ class BaseTransport(BaseEstimator): # pairwise distance self.cost_ = dist(Xs, Xt, metric=self.metric) - self.normalizeCost_(self.norm) + self.cost_ = cost_normalization(self.cost_, self.norm) if (ys is not None) and (yt is not None): @@ -1183,26 +1163,6 @@ class BaseTransport(BaseEstimator): return transp_Xt - def normalizeCost_(self, norm): - """ Apply normalization to the loss matrix - - - Parameters - ---------- - norm : str - type of normalization from 'median','max','log','loglog' - - """ - - if norm == "median": - self.cost_ /= float(np.median(self.cost_)) - elif norm == "max": - self.cost_ /= float(np.max(self.cost_)) - elif norm == "log": - self.cost_ = np.log(1 + self.cost_) - elif norm == "loglog": - self.cost_ = np.log(1 + np.log(1 + self.cost_)) - class SinkhornTransport(BaseTransport): """Domain Adapatation OT method based on Sinkhorn Algorithm -- cgit v1.2.3