summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-04 13:56:51 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-04 13:56:51 +0200
commitd793f1f73e6f816458d8b307762675aa9fa84d22 (patch)
tree23e5975343c6622cc2a359258d4e15424bbfe3ea /ot/da.py
parent0b005906f9d78adbf4d52d2ea9610eb3fde96a7c (diff)
correction of semi supervised mode
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py77
1 files changed, 45 insertions, 32 deletions
diff --git a/ot/da.py b/ot/da.py
index 8294e8d..08e8a8d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1088,26 +1088,23 @@ class BaseTransport(BaseEstimator):
# pairwise distance
self.Cost = dist(Xs, Xt, metric=self.metric)
- if self.mode == "semisupervised":
-
- if (ys is not None) and (yt is not None):
-
- # assumes labeled source samples occupy the first rows
- # and labeled target samples occupy the first columns
- classes = np.unique(ys)
- for c in classes:
- ids = np.where(ys == c)
- idt = np.where(yt == c)
-
- # all the coefficients corresponding to a source sample
- # and a target sample with the same label gets a 0
- # transport cost
- for j in idt[0]:
- self.Cost[ids[0], j] = 0
- else:
- print("Warning: using unsupervised mode\
- \nto use semisupervised mode, please provide ys and yt")
- pass
+ if (ys is not None) and (yt is not None):
+
+ if self.limit_max != np.infty:
+ self.limit_max = self.limit_max * np.max(self.Cost)
+
+ # assumes labeled source samples occupy the first rows
+ # and labeled target samples occupy the first columns
+ classes = np.unique(ys)
+ for c in classes:
+ idx_s = np.where((ys != c) & (ys != -1))
+ idx_t = np.where(yt == c)
+
+ # all the coefficients corresponding to a source sample
+ # and a target sample :
+ # with different labels get a infinite
+ for j in idx_t[0]:
+ self.Cost[idx_s[0], j] = self.limit_max
# distribution estimation
self.mu_s = self.distribution_estimation(Xs)
@@ -1243,6 +1240,9 @@ class SinkhornTransport(BaseTransport):
Controls the verbosity of the optimization algorithm
log : int, optional (default=0)
Controls the logs of the optimization algorithm
+ limit_max: float, optional (defaul=np.infty)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit an infinite cost
Attributes
----------
Coupling_ : the optimal coupling
@@ -1257,19 +1257,19 @@ class SinkhornTransport(BaseTransport):
26, 2013
"""
- def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
+ def __init__(self, reg_e=1., max_iter=1000,
tol=10e-9, verbose=False, log=False,
metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
- out_of_sample_map='ferradans'):
+ out_of_sample_map='ferradans', limit_max=np.infty):
self.reg_e = reg_e
- self.mode = mode
self.max_iter = max_iter
self.tol = tol
self.verbose = verbose
self.log = log
self.metric = metric
+ self.limit_max = limit_max
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
@@ -1326,6 +1326,10 @@ class EMDTransport(BaseTransport):
Controls the verbosity of the optimization algorithm
log : int, optional (default=0)
Controls the logs of the optimization algorithm
+ limit_max: float, optional (default=10)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit an infinite cost
+ (10 times the maximum value of the cost matrix)
Attributes
----------
Coupling_ : the optimal coupling
@@ -1337,15 +1341,15 @@ class EMDTransport(BaseTransport):
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
"""
- def __init__(self, mode="unsupervised", verbose=False,
+ def __init__(self, verbose=False,
log=False, metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
- out_of_sample_map='ferradans'):
+ out_of_sample_map='ferradans', limit_max=10):
- self.mode = mode
self.verbose = verbose
self.log = log
self.metric = metric
+ self.limit_max = limit_max
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
@@ -1414,6 +1418,10 @@ class SinkhornLpl1Transport(BaseTransport):
Controls the verbosity of the optimization algorithm
log : int, optional (default=0)
Controls the logs of the optimization algorithm
+ limit_max: float, optional (defaul=np.infty)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit an infinite cost
+
Attributes
----------
Coupling_ : the optimal coupling
@@ -1431,16 +1439,15 @@ class SinkhornLpl1Transport(BaseTransport):
"""
- def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
+ def __init__(self, reg_e=1., reg_cl=0.1,
max_iter=10, max_inner_iter=200,
tol=10e-9, verbose=False, log=False,
metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
- out_of_sample_map='ferradans'):
+ out_of_sample_map='ferradans', limit_max=np.infty):
self.reg_e = reg_e
self.reg_cl = reg_cl
- self.mode = mode
self.max_iter = max_iter
self.max_inner_iter = max_inner_iter
self.tol = tol
@@ -1449,6 +1456,7 @@ class SinkhornLpl1Transport(BaseTransport):
self.metric = metric
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
@@ -1514,6 +1522,11 @@ class SinkhornL1l2Transport(BaseTransport):
Controls the verbosity of the optimization algorithm
log : int, optional (default=0)
Controls the logs of the optimization algorithm
+ limit_max: float, optional (default=10)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit an infinite cost
+ (10 times the maximum value of the cost matrix)
+
Attributes
----------
Coupling_ : the optimal coupling
@@ -1531,16 +1544,15 @@ class SinkhornL1l2Transport(BaseTransport):
"""
- def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
+ def __init__(self, reg_e=1., reg_cl=0.1,
max_iter=10, max_inner_iter=200,
tol=10e-9, verbose=False, log=False,
metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
- out_of_sample_map='ferradans'):
+ out_of_sample_map='ferradans', limit_max=10):
self.reg_e = reg_e
self.reg_cl = reg_cl
- self.mode = mode
self.max_iter = max_iter
self.max_inner_iter = max_inner_iter
self.tol = tol
@@ -1549,6 +1561,7 @@ class SinkhornL1l2Transport(BaseTransport):
self.metric = metric
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples