summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorSamarth Mishra <samarth4149@users.noreply.github.com>2020-08-24 06:08:15 -0400
committerGitHub <noreply@github.com>2020-08-24 12:08:15 +0200
commit679ed3120da21d620b7cd9a838e073c817653864 (patch)
tree851c25cc897cbcd5d3d9a5b15baa0901452524e7 /ot/da.py
parent23db72c49465a1eeb2897d4c6dd9c189aec9cd6e (diff)
Fix to BaseTransport.transform_labels() (#208)
* Fix to BaseTransport.transform_labels() Issue #207 * Fix - forgot to commit
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/ot/da.py b/ot/da.py
index b881a8b..f1e4769 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1111,7 +1111,7 @@ class BaseTransport(BaseEstimator):
D1 = np.zeros((n, len(ysTemp)))
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True)
# set nans to 0
transp[~ np.isfinite(transp)] = 0