From d793f1f73e6f816458d8b307762675aa9fa84d22 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Fri, 4 Aug 2017 13:56:51 +0200 Subject: correction of semi supervised mode --- ot/da.py | 77 +++++++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 45 insertions(+), 32 deletions(-) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index 8294e8d..08e8a8d 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1088,26 +1088,23 @@ class BaseTransport(BaseEstimator): # pairwise distance self.Cost = dist(Xs, Xt, metric=self.metric) - if self.mode == "semisupervised": - - if (ys is not None) and (yt is not None): - - # assumes labeled source samples occupy the first rows - # and labeled target samples occupy the first columns - classes = np.unique(ys) - for c in classes: - ids = np.where(ys == c) - idt = np.where(yt == c) - - # all the coefficients corresponding to a source sample - # and a target sample with the same label gets a 0 - # transport cost - for j in idt[0]: - self.Cost[ids[0], j] = 0 - else: - print("Warning: using unsupervised mode\ - \nto use semisupervised mode, please provide ys and yt") - pass + if (ys is not None) and (yt is not None): + + if self.limit_max != np.infty: + self.limit_max = self.limit_max * np.max(self.Cost) + + # assumes labeled source samples occupy the first rows + # and labeled target samples occupy the first columns + classes = np.unique(ys) + for c in classes: + idx_s = np.where((ys != c) & (ys != -1)) + idx_t = np.where(yt == c) + + # all the coefficients corresponding to a source sample + # and a target sample : + # with different labels get a infinite + for j in idx_t[0]: + self.Cost[idx_s[0], j] = self.limit_max # distribution estimation self.mu_s = self.distribution_estimation(Xs) @@ -1243,6 +1240,9 @@ class SinkhornTransport(BaseTransport): Controls the verbosity of the optimization algorithm log : int, optional (default=0) Controls the logs of the optimization algorithm + limit_max: float, optional (defaul=np.infty) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit an infinite cost Attributes ---------- Coupling_ : the optimal coupling @@ -1257,19 +1257,19 @@ class SinkhornTransport(BaseTransport): 26, 2013 """ - def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000, + def __init__(self, reg_e=1., max_iter=1000, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans'): + out_of_sample_map='ferradans', limit_max=np.infty): self.reg_e = reg_e - self.mode = mode self.max_iter = max_iter self.tol = tol self.verbose = verbose self.log = log self.metric = metric + self.limit_max = limit_max self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map @@ -1326,6 +1326,10 @@ class EMDTransport(BaseTransport): Controls the verbosity of the optimization algorithm log : int, optional (default=0) Controls the logs of the optimization algorithm + limit_max: float, optional (default=10) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit an infinite cost + (10 times the maximum value of the cost matrix) Attributes ---------- Coupling_ : the optimal coupling @@ -1337,15 +1341,15 @@ class EMDTransport(BaseTransport): on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 """ - def __init__(self, mode="unsupervised", verbose=False, + def __init__(self, verbose=False, log=False, metric="sqeuclidean", distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans'): + out_of_sample_map='ferradans', limit_max=10): - self.mode = mode self.verbose = verbose self.log = log self.metric = metric + self.limit_max = limit_max self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map @@ -1414,6 +1418,10 @@ class SinkhornLpl1Transport(BaseTransport): Controls the verbosity of the optimization algorithm log : int, optional (default=0) Controls the logs of the optimization algorithm + limit_max: float, optional (defaul=np.infty) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit an infinite cost + Attributes ---------- Coupling_ : the optimal coupling @@ -1431,16 +1439,15 @@ class SinkhornLpl1Transport(BaseTransport): """ - def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised", + def __init__(self, reg_e=1., reg_cl=0.1, max_iter=10, max_inner_iter=200, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans'): + out_of_sample_map='ferradans', limit_max=np.infty): self.reg_e = reg_e self.reg_cl = reg_cl - self.mode = mode self.max_iter = max_iter self.max_inner_iter = max_inner_iter self.tol = tol @@ -1449,6 +1456,7 @@ class SinkhornLpl1Transport(BaseTransport): self.metric = metric self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map + self.limit_max = limit_max def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples @@ -1514,6 +1522,11 @@ class SinkhornL1l2Transport(BaseTransport): Controls the verbosity of the optimization algorithm log : int, optional (default=0) Controls the logs of the optimization algorithm + limit_max: float, optional (default=10) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit an infinite cost + (10 times the maximum value of the cost matrix) + Attributes ---------- Coupling_ : the optimal coupling @@ -1531,16 +1544,15 @@ class SinkhornL1l2Transport(BaseTransport): """ - def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised", + def __init__(self, reg_e=1., reg_cl=0.1, max_iter=10, max_inner_iter=200, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans'): + out_of_sample_map='ferradans', limit_max=10): self.reg_e = reg_e self.reg_cl = reg_cl - self.mode = mode self.max_iter = max_iter self.max_inner_iter = max_inner_iter self.tol = tol @@ -1549,6 +1561,7 @@ class SinkhornL1l2Transport(BaseTransport): self.metric = metric self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map + self.limit_max = limit_max def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples -- cgit v1.2.3