summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2020-04-15 13:16:30 +0200
committerGitHub <noreply@github.com>2020-04-15 13:16:30 +0200
commitadc5570550676b63b9aabb2205a67c5b7c9187f3 (patch)
tree0082b2ea3843bb50738eb4689fb1eb9c74b85034 /ot/utils.py
parent4cd4e09f89fe6f95a07d632365612b797ab760da (diff)
parent7889484b79a425ebf3632444547a6092e814bf20 (diff)
Merge pull request #137 from ievred/jcpot
[MRG] Jcpot : Multi source DA with target shift
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/ot/utils.py b/ot/utils.py
index b71458b..c154f99 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -200,6 +200,28 @@ def dots(*args):
return reduce(np.dot, args)
+def label_normalization(y, start=0):
+ """ Transform labels to start at a given value
+
+ Parameters
+ ----------
+ y : array-like, shape (n, )
+ The vector of labels to be normalized.
+ start : int
+ Desired value for the smallest label in y (default=0)
+
+ Returns
+ -------
+ y : array-like, shape (n1, )
+ The input vector of labels normalized according to given start value.
+ """
+
+ diff = np.min(np.unique(y)) - start
+ if diff != 0:
+ y -= diff
+ return y
+
+
def fun(f, q_in, q_out):
""" Utility function for parmap with no serializing problems """
while True: