From 679ed3120da21d620b7cd9a838e073c817653864 Mon Sep 17 00:00:00 2001 From: Samarth Mishra Date: Mon, 24 Aug 2020 06:08:15 -0400 Subject: Fix to BaseTransport.transform_labels() (#208) * Fix to BaseTransport.transform_labels() Issue #207 * Fix - forgot to commit --- ot/da.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index b881a8b..f1e4769 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1111,7 +1111,7 @@ class BaseTransport(BaseEstimator): D1 = np.zeros((n, len(ysTemp))) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True) # set nans to 0 transp[~ np.isfinite(transp)] = 0 -- cgit v1.2.3