summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authorNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-08-31 16:44:18 +0200
committerNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-08-31 16:44:18 +0200
commit3007f1da1094f93fa4216386666085cf60316b04 (patch)
tree5e07b1674769403f2e09476b7d73f1e00a845384 /ot/gromov.py
parent0a68bf4e83ee9092c3f3878115fea894922d7d56 (diff)
Minor corrections suggested by @agramfort + new barycenter example + test function
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py36
1 files changed, 16 insertions, 20 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 7cf3b42..421ed3f 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -23,7 +23,7 @@ def square_loss(a, b):
Returns the value of L(a,b)=(1/2)*|a-b|^2
"""
- return (1 / 2) * (a - b)**2
+ return 0.5 * (a - b)**2
def kl_loss(a, b):
@@ -54,9 +54,9 @@ def tensor_square_loss(C1, C2, T):
Parameters
----------
- C1 : np.ndarray(ns,ns)
+ C1 : ndarray, shape (ns, ns)
Metric cost matrix in the source space
- C2 : np.ndarray(nt,nt)
+ C2 : ndarray, shape (nt, nt)
Metric costfr matrix in the target space
T : np.ndarray(ns,nt)
Coupling between source and target spaces
@@ -87,7 +87,7 @@ def tensor_square_loss(C1, C2, T):
return b
tens = -np.dot(h1(C1), T).dot(h2(C2).T)
- tens = tens - tens.min()
+ tens -= tens.min()
return np.array(tens)
@@ -112,9 +112,9 @@ def tensor_kl_loss(C1, C2, T):
Parameters
----------
- C1 : np.ndarray(ns,ns)
+ C1 : ndarray, shape (ns, ns)
Metric cost matrix in the source space
- C2 : np.ndarray(nt,nt)
+ C2 : ndarray, shape (nt, nt)
Metric costfr matrix in the target space
T : np.ndarray(ns,nt)
Coupling between source and target spaces
@@ -149,7 +149,7 @@ def tensor_kl_loss(C1, C2, T):
return np.log(b + 1e-15)
tens = -np.dot(h1(C1), T).dot(h2(C2).T)
- tens = tens - tens.min()
+ tens -= tens.min()
return np.array(tens)
@@ -175,9 +175,8 @@ def update_square_loss(p, lambdas, T, Cs):
"""
- tmpsum = np.sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
- for s in range(len(T))])
- ppt = np.dot(p, p.T)
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
return(np.divide(tmpsum, ppt))
@@ -203,9 +202,8 @@ def update_kl_loss(p, lambdas, T, Cs):
"""
- tmpsum = np.sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
- for s in range(len(T))])
- ppt = np.dot(p, p.T)
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
return(np.exp(np.divide(tmpsum, ppt)))
@@ -239,9 +237,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
Parameters
----------
- C1 : np.ndarray(ns,ns)
+ C1 : ndarray, shape (ns, ns)
Metric cost matrix in the source space
- C2 : np.ndarray(nt,nt)
+ C2 : ndarray, shape (nt, nt)
Metric costfr matrix in the target space
p : np.ndarray(ns,)
distribution in the source space
@@ -271,7 +269,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
C1 = np.asarray(C1, dtype=np.float64)
C2 = np.asarray(C2, dtype=np.float64)
- T = np.dot(p, q.T) # Initialization
+ T = np.outer(p, q) # Initialization
cpt = 0
err = 1
@@ -333,9 +331,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
Parameters
----------
- C1 : np.ndarray(ns,ns)
+ C1 : ndarray, shape (ns, ns)
Metric cost matrix in the source space
- C2 : np.ndarray(nt,nt)
+ C2 : ndarray, shape (nt, nt)
Metric costfr matrix in the target space
p : np.ndarray(ns,)
distribution in the source space
@@ -434,8 +432,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000
Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
lambdas = np.asarray(lambdas, dtype=np.float64)
- T = [0 for s in range(S)]
-
# Initialization of C : random SPD matrix
xalea = np.random.randn(N, 2)
C = dist(xalea, xalea)