summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2020-04-09 14:14:34 +0200
committerGitHub <noreply@github.com>2020-04-09 14:14:34 +0200
commitfff2463aafd58343c8bc2ed7875622e16a8c1cee (patch)
treeb23efef253c4cc42c13bf3f7aad671f27bf43a3d /ot
parent9f63ee92e281427ab3d520f75bb9c3406b547365 (diff)
parent4cd4e09f89fe6f95a07d632365612b797ab760da (diff)
Merge branch 'master' into partial-W-and-GW
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py8
-rw-r--r--ot/gromov.py48
-rw-r--r--ot/lp/__init__.py21
3 files changed, 42 insertions, 35 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 2707b7c..d5e3563 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -9,6 +9,7 @@ Bregman projections for regularized OT
# Titouan Vayer <titouan.vayer@irisa.fr>
# Hicham Janati <hicham.janati@inria.fr>
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
+# Alexander Tong <alexander.tong@yale.edu>
#
# License: MIT License
@@ -1346,12 +1347,17 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
err = 1
# build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
t = np.linspace(0, 1, A.shape[1])
[Y, X] = np.meshgrid(t, t)
xi1 = np.exp(-(X - Y)**2 / reg)
+ t = np.linspace(0, 1, A.shape[2])
+ [Y, X] = np.meshgrid(t, t)
+ xi2 = np.exp(-(X - Y)**2 / reg)
+
def K(x):
- return np.dot(np.dot(xi1, x), xi1)
+ return np.dot(np.dot(xi1, x), xi2)
while (err > stopThr and cpt < numItermax):
diff --git a/ot/gromov.py b/ot/gromov.py
index 9869341..43780a4 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -433,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
where :
- M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - p and q are source and target weights (sum to 1)
- L is a loss function to account for the misfit between the similarity matrices
The algorithm used for solving the problem is conditional gradient as discussed in [24]_
@@ -453,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Distribution in the target space
loss_fun : str, optional
Loss function used for the solver
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
If True the steps of the line-search is found via an armijo research. Else closed form is used.
If there is convergence issues use False.
+ log : bool, optional
+ record log if True
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
@@ -493,11 +488,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
log['fgw_dist'] = log['loss'][::-1][0]
return res, log
else:
- return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -515,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
where :
- M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - p and q are source and target weights (sum to 1)
- L is a loss function to account for the misfit between the similarity matrices
The algorithm used for solving the problem is conditional gradient as discussed in [1]_
@@ -534,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
Distribution in the target space.
loss_fun : str, optional
Loss function used for the solver.
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Record log if True.
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
If True the steps of the line-search is found via an armijo research.
Else closed form is used. If there is convergence issues use False.
+ log : bool, optional
+ Record log if True.
**kwargs : dict
Parameters can be directly pased to the ot.optim.cg solver.
@@ -573,7 +563,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
if log:
log['fgw_dist'] = log['loss'][::-1][0]
log['T'] = res
@@ -994,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Whether to fix the structure of the barycenter during the updates
fixed_features : bool
Whether to fix the feature of the barycenter during the updates
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshol on error (>0).
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
init_C : ndarray, shape (N,N), optional
Initialization for the barycenters' structure matrix. If not set
a random init is used.
@@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
+ T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index cdd505d..f4f6861 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -12,16 +12,16 @@ Solvers for the original linear program OT problem
import multiprocessing
import sys
+
import numpy as np
from scipy.sparse import coo_matrix
-from .import cvx
-
+from . import cvx
+from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import parmap
-from .cvx import barycenter
from ..utils import dist
+from ..utils import parmap
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -458,7 +458,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return res
-def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
+def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
+ stopThr=1e-7, verbose=False, log=None):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -525,8 +526,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
T_sum = np.zeros((k, d))
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
-
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
+ weights.tolist()):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
@@ -651,12 +652,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
- x_a_1d = x_a.reshape((-1, ))
- x_b_1d = x_b.reshape((-1, ))
+ x_a_1d = x_a.reshape((-1,))
+ x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)
perm_b = np.argsort(x_b_1d)
- G_sorted, indices, cost = emd_1d_sorted(a, b,
+ G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
x_a_1d[perm_a], x_b_1d[perm_b],
metric=metric, p=p)
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),