diff options
Diffstat (limited to 'ot/factored.py')
-rw-r--r-- | ot/factored.py | 145 |
1 files changed, 145 insertions, 0 deletions
diff --git a/ot/factored.py b/ot/factored.py new file mode 100644 index 0000000..abc2445 --- /dev/null +++ b/ot/factored.py @@ -0,0 +1,145 @@ +""" +Factored OT solvers (low rank, cost or OT plan) +""" + +# Author: Remi Flamary <remi.flamary@polytehnique.edu> +# +# License: MIT License + +from .backend import get_backend +from .utils import dist +from .lp import emd +from .bregman import sinkhorn + +__all__ = ['factored_optimal_transport'] + + +def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs): + r"""Solves factored OT problem and return OT plans and intermediate distribution + + This function solve the following OT problem [40]_ + + .. math:: + \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) + + where : + + - :math:`\mu_a` and :math:`\mu_b` are empirical distributions. + - :math:`\mu` is an empirical distribution with r samples + + And returns the two OT plans between + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Uses the conditional gradient algorithm to solve the problem proposed in + :ref:`[39] <references-weak>`. + + Parameters + ---------- + Xa : (ns,d) array-like, float + Source samples + Xb : (nt,d) array-like, float + Target samples + a : (ns,) array-like, float + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list)) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on the relative variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + Ga: array-like, shape (ns, r) + Optimal transportation matrix between source and the intermediate + distribution + Gb: array-like, shape (r, nt) + Optimal transportation matrix between the intermediate and target + distribution + X: array-like, shape (r, d) + Support of the intermediate distribution + log: dict, optional + If input log is true, a dictionary containing the cost and dual + variables and exit status + + + .. _references-factored: + References + ---------- + .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, + G., & Weed, J. (2019, April). Statistical optimal transport via factored + couplings. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2454-2465). PMLR. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General + regularized OT + """ + + nx = get_backend(Xa, Xb) + + n_a = Xa.shape[0] + n_b = Xb.shape[0] + d = Xa.shape[1] + + if a is None: + a = nx.ones((n_a), type_as=Xa) / n_a + if b is None: + b = nx.ones((n_b), type_as=Xb) / n_b + + if X0 is None: + X = nx.randn(r, d, type_as=Xa) + else: + X = X0 + + w = nx.ones(r, type_as=Xa) / r + + def solve_ot(X1, X2, w1, w2): + M = dist(X1, X2) + if reg > 0: + G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs) + log['cost'] = nx.sum(G * M) + return G, log + else: + return emd(w1, w2, M, log=True, **kwargs) + + norm_delta = [] + + # solve the barycenter + for i in range(numItermax): + + old_X = X + + # solve OT with template + Ga, loga = solve_ot(Xa, X, a, w) + Gb, logb = solve_ot(X, Xb, w, b) + + X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r + + delta = nx.norm(X - old_X) + if delta < stopThr: + break + if log: + norm_delta.append(delta) + + if log: + log_dic = {'delta_iter': norm_delta, + 'ua': loga['u'], + 'va': loga['v'], + 'ub': logb['u'], + 'vb': logb['v'], + 'costa': loga['cost'], + 'costb': logb['cost'], + } + return Ga, Gb, X, log_dic + + return Ga, Gb, X |