summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-04 12:04:04 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-04 12:04:04 +0200
commit0b005906f9d78adbf4d52d2ea9610eb3fde96a7c (patch)
tree0487c2d8c5fc3ef55524a988314ed79144eaf45d /ot/da.py
parent727077ad7db503955aea0751abf9f361f1d82af7 (diff)
semi supervised mode supported
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py21
1 files changed, 19 insertions, 2 deletions
diff --git a/ot/da.py b/ot/da.py
index 6100d15..8294e8d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1089,8 +1089,25 @@ class BaseTransport(BaseEstimator):
self.Cost = dist(Xs, Xt, metric=self.metric)
if self.mode == "semisupervised":
- print("TODO: modify cost matrix accordingly")
- pass
+
+ if (ys is not None) and (yt is not None):
+
+ # assumes labeled source samples occupy the first rows
+ # and labeled target samples occupy the first columns
+ classes = np.unique(ys)
+ for c in classes:
+ ids = np.where(ys == c)
+ idt = np.where(yt == c)
+
+ # all the coefficients corresponding to a source sample
+ # and a target sample with the same label gets a 0
+ # transport cost
+ for j in idt[0]:
+ self.Cost[ids[0], j] = 0
+ else:
+ print("Warning: using unsupervised mode\
+ \nto use semisupervised mode, please provide ys and yt")
+ pass
# distribution estimation
self.mu_s = self.distribution_estimation(Xs)