author     Gard Spreemann <gspr@nonempty.org>    2022-04-27 11:49:23 +0200
committer  Gard Spreemann <gspr@nonempty.org>    2022-04-27 11:49:23 +0200
commit     35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree       6bc637624004713808d3097b95acdccbb9608e52 /ot/weak.py
parent     c4753bd3f74139af8380127b66b484bc09b50661 (diff)
parent     eccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'ot/weak.py')
-rw-r--r--  ot/weak.py  124
1 file changed, 124 insertions, 0 deletions
diff --git a/ot/weak.py b/ot/weak.py
new file mode 100644
index 0000000..f7d5b23
--- /dev/null
+++ b/ot/weak.py
@@ -0,0 +1,124 @@
+"""
+Weak optimal transport solvers
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .optim import cg
+import numpy as np
+
+__all__ = ['weak_optimal_transport']
+
+
+def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
+ r"""Solves the weak optimal transport problem between two empirical distributions
+
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \|X_a - \mathrm{diag}(1/\mathbf{a}) \gamma X_b\|_F^2
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
+
+ where :
+
+ - :math:`X_a` and :math:`X_b` are the sample matrices.
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+
+ Uses the conditional gradient algorithm to solve the problem proposed
+ in :ref:`[39] <references-weak>`.
+
+ Parameters
+ ----------
+ Xa : (ns,d) array-like, float
+ Source samples
+ Xb : (nt,d) array-like, float
+ Target samples
+ a : (ns,) array-like, float
+ Source histogram (uniform weights if None)
+ b : (nt,) array-like, float
+ Target histogram (uniform weights if None)
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True
+
+
+ Returns
+ -------
+ gamma : array-like, shape (ns, nt)
+ Optimal transport plan (coupling matrix) for the
+ given parameters
+ log : dict, optional
+ Log dictionary with the cost, dual variables and exit
+ status, returned only if the input log is True
+
+
+ .. _references-weak:
+ References
+ ----------
+ .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
+ Kantorovich duality for general transport costs and applications.
+ Journal of Functional Analysis, 273(11), 3327-3405.
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+ """
+
+ nx = get_backend(Xa, Xb)
+
+ Xa2 = nx.to_numpy(Xa)
+ Xb2 = nx.to_numpy(Xb)
+
+ if a is None:
+ a2 = np.ones((Xa.shape[0])) / Xa.shape[0]
+ else:
+ a2 = nx.to_numpy(a)
+ if b is None:
+ b2 = np.ones((Xb.shape[0])) / Xb.shape[0]
+ else:
+ b2 = nx.to_numpy(b)
+
+ # init uniform
+ if G0 is None:
+ T0 = a2[:, None] * b2[None, :]
+ else:
+ T0 = nx.to_numpy(G0)
+
+ # weak OT loss
+ def f(T):
+ return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1))
+
+ # weak OT gradient
+ def df(T):
+ return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T)
+
+ # solve with conditional gradient and return solution
+ if log:
+ res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs)
+ log['u'] = nx.from_numpy(log['u'], type_as=Xa)
+ log['v'] = nx.from_numpy(log['v'], type_as=Xb)
+ return nx.from_numpy(res, type_as=Xa), log
+ else:
+ return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa)