diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 29 | ||||
-rw-r--r-- | ot/stochastic.py | 3 |
2 files changed, 21 insertions, 11 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index f873a85..47554fb 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1317,9 +1317,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`M` is the (ns,nt) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`a` and :math:`b` are source and target weights (sum to 1) Parameters @@ -1399,7 +1399,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num The function solves the following optimization problem: .. math:: - W = \min_\gamma_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) s.t. \gamma 1 = a @@ -1408,9 +1408,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`M` is the (ns,nt) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`a` and :math:`b` are source and target weights (sum to 1) Parameters @@ -1484,13 +1484,20 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli ''' Compute the sinkhorn divergence loss from empirical data - The function solves the following optimization problem: + The function solves the following optimization problems and return the + sinkhorn divergence :math:`S`: .. math:: - S = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - - \min_\gamma_a <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) - - \min_\gamma_b <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) + W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) + + W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) + + S &= W - 1/2 * (W_a + W_b) + + .. math:: s.t. \gamma 1 = a \gamma^T 1= b @@ -1510,9 +1517,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli \gamma_b\geq 0 where : - - M (resp. :math:`M_a, M_b) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt)) + - :math:`M` (resp. :math:`M_a, M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt)) - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`a` and :math:`b` are source and target weights (sum to 1) Parameters diff --git a/ot/stochastic.py b/ot/stochastic.py index 0db39c8..85c4230 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -348,8 +348,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + s.t. \gamma 1 = a + \gamma^T 1= b + \gamma \geq 0 Where : |