summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 9e9989f..f873a85 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -10,6 +10,7 @@ Bregman projections for regularized OT
# License: MIT License
import numpy as np
+from .utils import unif, dist
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -1375,11 +1376,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
'''
if a is None:
- a = ot.unif(np.shape(X_s)[0])
+ a = unif(np.shape(X_s)[0])
if b is None:
- b = ot.unif(np.shape(X_t)[0])
+ b = unif(np.shape(X_t)[0])
- M = ot.dist(X_s, X_t, metric=metric)
+ M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
@@ -1465,11 +1466,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
'''
if a is None:
- a = ot.unif(np.shape(X_s)[0])
+ a = unif(np.shape(X_s)[0])
if b is None:
- b = ot.unif(np.shape(X_t)[0])
+ b = unif(np.shape(X_t)[0])
- M = ot.dist(X_s, X_t, metric=metric)
+ M = dist(X_s, X_t, metric=metric)
if log:
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)