diff options
-rw-r--r-- | ot/bregman.py | 29 | ||||
-rw-r--r-- | test/test_bregman.py | 6 |
2 files changed, 17 insertions, 18 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index f1b18f8..f6aa339 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1375,17 +1375,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI ''' if a is None: - a = ot.unif(np.shape(X_s)[0]) + a = utils.unif(np.shape(X_s)[0]) if b is None: - b = ot.unif(np.shape(X_t)[0]) + b = utils.unif(np.shape(X_t)[0]) + M = ot.dist(X_s, X_t, metric=metric) - if log == False: - pi = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) - return pi - if log == True: - pi, log = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + return pi def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): @@ -1464,18 +1465,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num ''' if a is None: - a = ot.unif(np.shape(X_s)[0]) + a = utils.unif(np.shape(X_s)[0]) if b is None: - b = ot.unif(np.shape(X_t)[0]) + b = utils.unif(np.shape(X_t)[0]) M = ot.dist(X_s, X_t, metric=metric) - if log == False: - sinkhorn_loss = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - return sinkhorn_loss - if log == True: - sinkhorn_loss, log = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): diff --git a/test/test_bregman.py b/test/test_bregman.py index b890df1..8b001a7 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -195,13 +195,11 @@ def test_empirical_sinkhorn(): n = 100 a = ot.unif(n) b = ot.unif(n) - M = ot.dist(X_s, X_t) - M_e = ot.dist(X_s, X_t, metric='euclidean') - - rng = np.random.RandomState(0) X_s = np.reshape(np.arange(n), (n, 1)) X_t = np.reshape(np.arange(0, n), (n, 1)) + M = ot.dist(X_s, X_t) + M_e = ot.dist(X_s, X_t, metric='euclidean') G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) |