summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py75
-rw-r--r--ot/utils.py22
-rw-r--r--test/test_da.py1
3 files changed, 60 insertions, 38 deletions
diff --git a/ot/da.py b/ot/da.py
index 29b0a8b..4318c0d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -16,7 +16,7 @@ import scipy.linalg as linalg
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
-from .utils import unif, dist, kernel, cost_normalization
+from .utils import unif, dist, kernel, cost_normalization, label_normalization
from .utils import check_params, BaseEstimator
from .unbalanced import sinkhorn_unbalanced
from .optim import cg
@@ -786,6 +786,9 @@ class BaseTransport(BaseEstimator):
transform method should always get as input a Xs parameter
inverse_transform method should always get as input a Xt parameter
+
+ transform_labels method should always get as input a ys parameter
+ inverse_transform_labels method should always get as input a yt parameter
"""
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
@@ -944,7 +947,7 @@ class BaseTransport(BaseEstimator):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels ys to obtain estimated target labels
+ """Propagate source labels ys to obtain estimated target labels as in [27]
Parameters
----------
@@ -955,14 +958,23 @@ class BaseTransport(BaseEstimator):
-------
transp_ys : array-like, shape (n_target_samples,)
Estimated target labels.
+
+ References
+ ----------
+
+ .. [27] Ievgen Redko, Nicolas Courty, RĂ©mi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
+
"""
# check the necessary inputs parameters are here
if check_params(ys=ys):
- classes = np.unique(ys)
+ ysTemp = label_normalization(np.copy(ys))
+ classes = np.unique(ysTemp)
n = len(classes)
- D1 = np.zeros((n, len(ys)))
+ D1 = np.zeros((n, len(ysTemp)))
# perform label propagation
transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
@@ -970,18 +982,13 @@ class BaseTransport(BaseEstimator):
# set nans to 0
transp[~ np.isfinite(transp)] = 0
- if np.min(classes) != 0:
- ys = ys - np.min(classes)
- classes = np.unique(ys)
-
for c in classes:
- D1[int(c), ys == c] = 1
+ D1[int(c), ysTemp == c] = 1
# compute transported samples
transp_ys = np.dot(D1, transp)
- return np.argmax(transp_ys,axis=0)
-
+ return np.argmax(transp_ys, axis=0)
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
@@ -1066,9 +1073,10 @@ class BaseTransport(BaseEstimator):
# check the necessary inputs parameters are here
if check_params(yt=yt):
- classes = np.unique(yt)
+ ytTemp = label_normalization(np.copy(yt))
+ classes = np.unique(ytTemp)
n = len(classes)
- D1 = np.zeros((n, len(yt)))
+ D1 = np.zeros((n, len(ytTemp)))
# perform label propagation
transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
@@ -1076,17 +1084,13 @@ class BaseTransport(BaseEstimator):
# set nans to 0
transp[~ np.isfinite(transp)] = 0
- if np.min(classes) != 0:
- yt = yt - np.min(classes)
- classes = np.unique(yt)
-
for c in classes:
- D1[int(c), yt == c] = 1
+ D1[int(c), ytTemp == c] = 1
# compute transported samples
transp_ys = np.dot(D1, transp.T)
- return np.argmax(transp_ys,axis=0)
+ return np.argmax(transp_ys, axis=0)
class LinearTransport(BaseTransport):
@@ -2163,7 +2167,7 @@ class JCPOTTransport(BaseTransport):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels ys to obtain target labels
+ """Propagate source labels ys to obtain target labels as in [27]
Parameters
----------
@@ -2178,11 +2182,12 @@ class JCPOTTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(ys=ys):
- yt = np.zeros((len(np.unique(np.concatenate(ys))),self.xt_.shape[0]))
+ yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0]))
for i in range(len(ys)):
- classes = np.unique(ys[i])
+ ysTemp = label_normalization(np.copy(ys[i]))
+ classes = np.unique(ysTemp)
n = len(classes)
- ns = len(ys[i])
+ ns = len(ysTemp)
# perform label propagation
transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
@@ -2195,16 +2200,13 @@ class JCPOTTransport(BaseTransport):
else:
D1 = np.zeros((n, ns))
- if np.min(classes) != 0:
- ys = ys - np.min(classes)
- classes = np.unique(ys)
-
for c in classes:
- D1[int(c), ys == c] = 1
+ D1[int(c), ysTemp == c] = 1
+
# compute transported samples
- yt = yt + np.dot(D1, transp)/len(ys)
+ yt = yt + np.dot(D1, transp) / len(ys)
- return np.argmax(yt,axis=0)
+ return np.argmax(yt, axis=0)
def inverse_transform_labels(self, yt=None):
"""Propagate source labels ys to obtain target labels
@@ -2223,16 +2225,13 @@ class JCPOTTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(yt=yt):
transp_ys = []
- classes = np.unique(yt)
+ ytTemp = label_normalization(np.copy(yt))
+ classes = np.unique(ytTemp)
n = len(classes)
- D1 = np.zeros((n, len(yt)))
-
- if np.min(classes) != 0:
- yt = yt - np.min(classes)
- classes = np.unique(yt)
+ D1 = np.zeros((n, len(ytTemp)))
for c in classes:
- D1[int(c), yt == c] = 1
+ D1[int(c), ytTemp == c] = 1
for i in range(len(self.xs_)):
@@ -2243,6 +2242,6 @@ class JCPOTTransport(BaseTransport):
transp[~ np.isfinite(transp)] = 0
# compute transported labels
- transp_ys.append(np.argmax(np.dot(D1, transp.T),axis=0))
+ transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0))
return transp_ys
diff --git a/ot/utils.py b/ot/utils.py
index b71458b..c154f99 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -200,6 +200,28 @@ def dots(*args):
return reduce(np.dot, args)
+def label_normalization(y, start=0):
+ """ Transform labels to start at a given value
+
+ Parameters
+ ----------
+ y : array-like, shape (n, )
+ The vector of labels to be normalized.
+ start : int
+ Desired value for the smallest label in y (default=0)
+
+ Returns
+ -------
+ y : array-like, shape (n1, )
+ The input vector of labels normalized according to given start value.
+ """
+
+ diff = np.min(np.unique(y)) - start
+ if diff != 0:
+ y -= diff
+ return y
+
+
def fun(f, q_in, q_out):
""" Utility function for parmap with no serializing problems """
while True:
diff --git a/test/test_da.py b/test/test_da.py
index 4eb6de0..d96046d 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -650,6 +650,7 @@ def test_jcpot_transport_class():
transp_ys = otda.inverse_transform_labels(yt)
[assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)]
+
def test_jcpot_barycenter():
"""test_jcpot_barycenter
"""