diff options
-rw-r--r-- | ot/bregman.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 9e9989f..f873a85 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -10,6 +10,7 @@ Bregman projections for regularized OT # License: MIT License import numpy as np +from .utils import unif, dist def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -1375,11 +1376,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI ''' if a is None: - a = ot.unif(np.shape(X_s)[0]) + a = unif(np.shape(X_s)[0]) if b is None: - b = ot.unif(np.shape(X_t)[0]) + b = unif(np.shape(X_t)[0]) - M = ot.dist(X_s, X_t, metric=metric) + M = dist(X_s, X_t, metric=metric) if log: pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) @@ -1465,11 +1466,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num ''' if a is None: - a = ot.unif(np.shape(X_s)[0]) + a = unif(np.shape(X_s)[0]) if b is None: - b = ot.unif(np.shape(X_t)[0]) + b = unif(np.shape(X_t)[0]) - M = ot.dist(X_s, X_t, metric=metric) + M = dist(X_s, X_t, metric=metric) if log: sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) |