diff options
author | Slasnista <stan.chambon@gmail.com> | 2017-08-04 12:04:04 +0200 |
---|---|---|
committer | Nicolas Courty <Nico@MacBook-Pro-de-Nicolas.local> | 2017-09-01 11:09:13 +0200 |
commit | 4e562a1ce24119b8c9c1efb9d078762904c5d78a (patch) | |
tree | 0487c2d8c5fc3ef55524a988314ed79144eaf45d /ot/da.py | |
parent | 2005a09548a6f6d42cd9aafadbb4583e4029936c (diff) |
semi supervised mode supported
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 21 |
1 files changed, 19 insertions, 2 deletions
@@ -1089,8 +1089,25 @@ class BaseTransport(BaseEstimator): self.Cost = dist(Xs, Xt, metric=self.metric) if self.mode == "semisupervised": - print("TODO: modify cost matrix accordingly") - pass + + if (ys is not None) and (yt is not None): + + # assumes labeled source samples occupy the first rows + # and labeled target samples occupy the first columns + classes = np.unique(ys) + for c in classes: + ids = np.where(ys == c) + idt = np.where(yt == c) + + # all the coefficients corresponding to a source sample + # and a target sample with the same label gets a 0 + # transport cost + for j in idt[0]: + self.Cost[ids[0], j] = 0 + else: + print("Warning: using unsupervised mode\ + \nto use semisupervised mode, please provide ys and yt") + pass # distribution estimation self.mu_s = self.distribution_estimation(Xs) |