summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoraje <leo_g_autheron@hotmail.fr>2017-08-30 10:53:31 +0200
committeraje <leo_g_autheron@hotmail.fr>2017-08-30 10:53:31 +0200
commit0316d552fa6005aaf0f6231eb9ca20441d5a2532 (patch)
tree5faf38ccb37e8476bd9b5afd807fa3de69d7d6c8
parent5bbea9c608e66b499f9858af408cc65c07cf4ac2 (diff)
Move normalize function in utils.py
-rw-r--r--ot/da.py52
-rw-r--r--ot/utils.py33
2 files changed, 39 insertions, 46 deletions
diff --git a/ot/da.py b/ot/da.py
index b4a69b1..61a3ba0 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -13,7 +13,7 @@ import numpy as np
from .bregman import sinkhorn
from .lp import emd
-from .utils import unif, dist, kernel
+from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, deprecated, BaseEstimator
from .optim import cg
from .optim import gcg
@@ -673,7 +673,7 @@ class OTDA(object):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.normalizeM(norm)
+ self.M = cost_normalization(self.M, norm)
self.G = emd(ws, wt, self.M, max_iter)
self.computed = True
@@ -741,26 +741,6 @@ class OTDA(object):
# apply the delta to the interpolation
return xf[idx, :] + x - x0[idx, :]
- def normalizeM(self, norm):
- """ Apply normalization to the loss matrix
-
-
- Parameters
- ----------
- norm : str
- type of normalization from 'median','max','log','loglog'
-
- """
-
- if norm == "median":
- self.M /= float(np.median(self.M))
- elif norm == "max":
- self.M /= float(np.max(self.M))
- elif norm == "log":
- self.M = np.log(1 + self.M)
- elif norm == "loglog":
- self.M = np.log(1 + np.log(1 + self.M))
-
@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
" removed in 0.5 \nUse class SinkhornTransport instead.")
@@ -787,7 +767,7 @@ class OTDA_sinkhorn(OTDA):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.normalizeM(norm)
+ self.M = cost_normalization(self.M, norm)
self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
self.computed = True
@@ -816,7 +796,7 @@ class OTDA_lpl1(OTDA):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.normalizeM(norm)
+ self.M = cost_normalization(self.M, norm)
self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
self.computed = True
@@ -845,7 +825,7 @@ class OTDA_l1l2(OTDA):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.normalizeM(norm)
+ self.M = cost_normalization(self.M, norm)
self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
self.computed = True
@@ -1001,7 +981,7 @@ class BaseTransport(BaseEstimator):
# pairwise distance
self.cost_ = dist(Xs, Xt, metric=self.metric)
- self.normalizeCost_(self.norm)
+ self.cost_ = cost_normalization(self.cost_, self.norm)
if (ys is not None) and (yt is not None):
@@ -1183,26 +1163,6 @@ class BaseTransport(BaseEstimator):
return transp_Xt
- def normalizeCost_(self, norm):
- """ Apply normalization to the loss matrix
-
-
- Parameters
- ----------
- norm : str
- type of normalization from 'median','max','log','loglog'
-
- """
-
- if norm == "median":
- self.cost_ /= float(np.median(self.cost_))
- elif norm == "max":
- self.cost_ /= float(np.max(self.cost_))
- elif norm == "log":
- self.cost_ = np.log(1 + self.cost_)
- elif norm == "loglog":
- self.cost_ = np.log(1 + np.log(1 + self.cost_))
-
class SinkhornTransport(BaseTransport):
"""Domain Adaptation OT method based on Sinkhorn Algorithm
diff --git a/ot/utils.py b/ot/utils.py
index 01f2a67..31a002b 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -134,6 +134,39 @@ def dist0(n, method='lin_square'):
return res
def cost_normalization(C, norm=None):
    """Apply normalization to the loss matrix.

    The input array is never modified: every normalizing branch builds and
    returns a new array, so callers can keep using the raw cost matrix.

    Parameters
    ----------
    C : np.array (n1, n2)
        The cost matrix to normalize.
    norm : str, optional
        Type of normalization, one of 'median', 'max', 'log' or 'loglog'.
        Any other value (including the default ``None``) leaves the matrix
        unnormalized.

    Returns
    -------
    C : np.array (n1, n2)
        The cost matrix normalized according to ``norm``. When no
        normalization is applied, the input object itself is returned.

    """

    if norm == "median":
        # Out-of-place division: ``C /= x`` would mutate the caller's array
        # and fails with a casting error on integer-typed matrices.
        return C / float(np.median(C))
    if norm == "max":
        return C / float(np.max(C))
    if norm == "log":
        return np.log(1 + C)
    if norm == "loglog":
        return np.log(1 + np.log(1 + C))

    # Unknown or missing norm: hand the matrix back untouched.
    return C
+
+
def dots(*args):
    """Return the chained matrix product of all the given arrays.

    The arguments are folded from left to right with ``np.dot``, so
    ``dots(A, B, C)`` evaluates ``np.dot(np.dot(A, B), C)``.
    """
    def _times(left, right):
        # One fold step: multiply the running product by the next operand.
        return np.dot(left, right)

    return reduce(_times, args)