diff options
author | Slasnista <stan.chambon@gmail.com> | 2017-08-04 12:04:04 +0200 |
---|---|---|
committer | Slasnista <stan.chambon@gmail.com> | 2017-08-04 12:04:04 +0200 |
commit | 0b005906f9d78adbf4d52d2ea9610eb3fde96a7c (patch) | |
tree | 0487c2d8c5fc3ef55524a988314ed79144eaf45d /ot/da.py | |
parent | 727077ad7db503955aea0751abf9f361f1d82af7 (diff) |
semi supervised mode supported
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 21 |
1 files changed, 19 insertions, 2 deletions
@@ -1089,8 +1089,25 @@ class BaseTransport(BaseEstimator): self.Cost = dist(Xs, Xt, metric=self.metric) if self.mode == "semisupervised": - print("TODO: modify cost matrix accordingly") - pass + + if (ys is not None) and (yt is not None): + + # assumes labeled source samples occupy the first rows + # and labeled target samples occupy the first columns + classes = np.unique(ys) + for c in classes: + ids = np.where(ys == c) + idt = np.where(yt == c) + + # all the coefficients corresponding to a source sample + # and a target sample with the same label gets a 0 + # transport cost + for j in idt[0]: + self.Cost[ids[0], j] = 0 + else: + print("Warning: using unsupervised mode\ + \nto use semisupervised mode, please provide ys and yt") + pass # distribution estimation self.mu_s = self.distribution_estimation(Xs) |