summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/ot/da.py b/ot/da.py
index b881a8b..f1e4769 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1111,7 +1111,7 @@ class BaseTransport(BaseEstimator):
D1 = np.zeros((n, len(ysTemp)))
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True)
# set nans to 0
transp[~ np.isfinite(transp)] = 0