summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/bregman.py29
-rw-r--r--test/test_bregman.py6
2 files changed, 17 insertions, 18 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index f1b18f8..f6aa339 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1375,17 +1375,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
'''
if a is None:
- a = ot.unif(np.shape(X_s)[0])
+ a = utils.unif(np.shape(X_s)[0])
if b is None:
- b = ot.unif(np.shape(X_t)[0])
+ b = utils.unif(np.shape(X_t)[0])
+
M = ot.dist(X_s, X_t, metric=metric)
- if log == False:
- pi = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
- return pi
- if log == True:
- pi, log = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+ return pi
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
@@ -1464,18 +1465,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
'''
if a is None:
- a = ot.unif(np.shape(X_s)[0])
+ a = utils.unif(np.shape(X_s)[0])
if b is None:
- b = ot.unif(np.shape(X_t)[0])
+ b = utils.unif(np.shape(X_t)[0])
M = ot.dist(X_s, X_t, metric=metric)
- if log == False:
- sinkhorn_loss = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- return sinkhorn_loss
- if log == True:
- sinkhorn_loss, log = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_loss
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
diff --git a/test/test_bregman.py b/test/test_bregman.py
index b890df1..8b001a7 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -195,13 +195,11 @@ def test_empirical_sinkhorn():
n = 100
a = ot.unif(n)
b = ot.unif(n)
- M = ot.dist(X_s, X_t)
- M_e = ot.dist(X_s, X_t, metric='euclidean')
-
- rng = np.random.RandomState(0)
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_e = ot.dist(X_s, X_t, metric='euclidean')
G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)