summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-04 12:04:04 +0200
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:09:13 +0200
commit4e562a1ce24119b8c9c1efb9d078762904c5d78a (patch)
tree0487c2d8c5fc3ef55524a988314ed79144eaf45d /ot
parent2005a09548a6f6d42cd9aafadbb4583e4029936c (diff)
semi supervised mode supported
Diffstat (limited to 'ot')
-rw-r--r--ot/da.py21
1 files changed, 19 insertions, 2 deletions
diff --git a/ot/da.py b/ot/da.py
index 6100d15..8294e8d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1089,8 +1089,25 @@ class BaseTransport(BaseEstimator):
self.Cost = dist(Xs, Xt, metric=self.metric)
if self.mode == "semisupervised":
- print("TODO: modify cost matrix accordingly")
- pass
+
+ if (ys is not None) and (yt is not None):
+
+ # assumes labeled source samples occupy the first rows
+ # and labeled target samples occupy the first columns
+ classes = np.unique(ys)
+ for c in classes:
+ ids = np.where(ys == c)
+ idt = np.where(yt == c)
+
+ # all the coefficients corresponding to a source sample
+ # and a target sample with the same label gets a 0
+ # transport cost
+ for j in idt[0]:
+ self.Cost[ids[0], j] = 0
+ else:
+ print("Warning: using unsupervised mode\
+ \nto use semisupervised mode, please provide ys and yt")
+ pass
# distribution estimation
self.mu_s = self.distribution_estimation(Xs)