summaryrefslogtreecommitdiff
path: root/ot/factored.py
blob: 8d66158764c719911bdd8b261ed4ba05538c482e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Factored OT solvers (low rank, cost or OT plan)
"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

from .backend import get_backend
from .utils import dist
from .lp import emd
from .bregman import sinkhorn

# Names exported by a wildcard import of this module.
__all__ = ['factored_optimal_transport']


def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
    r"""Solves factored OT problem and return OT plans and intermediate distribution

    This function solves the following OT problem [40]_

    .. math::
        \mathop{\arg \min}_\mu \quad  W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)

    where :

    - :math:`\mu_a` and :math:`\mu_b`  are empirical distributions.
    - :math:`\mu` is an empirical distribution with r samples

    And returns the two OT plans between the source and the intermediate
    distribution and between the intermediate and the target distribution,
    together with the support of the intermediate distribution.

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.

    The problem is solved with an alternating scheme in the spirit of
    :ref:`[40] <references-factored>`: at each iteration the two OT problems
    (source/intermediate and intermediate/target) are solved for the current
    support, then the support is replaced by the average of the two
    barycentric mappings of its points. The intermediate distribution always
    has uniform weights.

    Parameters
    ----------
    Xa : (ns,d) array-like, float
        Source samples
    Xb : (nt,d) array-like, float
        Target samples
    a : (ns,) array-like, float, optional
        Source histogram (uniform weights if None)
    b : (nt,) array-like, float, optional
        Target histogram (uniform weights if None)
    reg : float, optional
        Regularization term. If `reg > 0` the two OT problems are solved with
        :any:`ot.bregman.sinkhorn` using this regularization, otherwise they
        are solved with :any:`ot.lp.emd`.
    r : int, optional
        Number of samples of the intermediate distribution. Only used when
        `X0` is None; otherwise the number of rows of `X0` is used.
    X0 : (r,d) array-like, float, optional
        Initial support of the intermediate distribution (drawn with the
        backend `randn` if None)
    stopThr : float, optional
        Stop threshold on the norm of the variation of the support between
        two iterations (>0)
    numItermax : int, optional
        Max number of iterations (must be at least 1)
    verbose : bool, optional
        Print information along iterations (currently unused)
    log : bool, optional
        record log if True
    **kwargs
        Additional keyword arguments forwarded to the inner OT solver
        (:any:`ot.bregman.sinkhorn` or :any:`ot.lp.emd`)


    Returns
    -------
    Ga: array-like, shape (ns, r)
        Optimal transportation matrix between source and the intermediate
        distribution
    Gb: array-like, shape (r, nt)
        Optimal transportation matrix between the intermediate and target
        distribution
    X: array-like, shape (r, d)
        Support of the intermediate distribution
    log: dict, optional
        If input log is true, a dictionary with keys `delta_iter` (norm of
        the support update at each non-converged iteration), `ua`, `va`,
        `ub`, `vb` (dual variables of the two OT problems) and `costa`,
        `costb` (their costs)

    Raises
    ------
    ValueError
        If `numItermax` is smaller than 1.


    .. _references-factored:
    References
    ----------
    .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
        G., & Weed, J. (2019, April). Statistical optimal transport via factored
        couplings. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2454-2465). PMLR.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.lp.emd : Exact OT solver
    """

    # Without at least one iteration the plans are never computed and there
    # is nothing to return: fail early with an explicit message.
    if numItermax < 1:
        raise ValueError(
            "numItermax must be at least 1, got {}".format(numItermax))

    nx = get_backend(Xa, Xb)

    n_a = Xa.shape[0]
    n_b = Xb.shape[0]
    d = Xa.shape[1]

    if a is None:
        a = nx.ones((n_a), type_as=Xa) / n_a
    if b is None:
        b = nx.ones((n_b), type_as=Xb) / n_b

    if X0 is None:
        X = nx.randn(r, d, type_as=Xa)
    else:
        X = X0

    # The weights and the rescaling below must match the actual number of
    # support points, which is given by X0 when it is provided.
    r = X.shape[0]

    # uniform weights on the intermediate distribution
    w = nx.ones(r, type_as=Xa) / r

    def solve_ot(X1, X2, w1, w2):
        # Solve one OT problem between two weighted point clouds and return
        # the plan with a log dict that always contains a 'cost' entry.
        M = dist(X1, X2)
        if reg > 0:
            G, log_ot = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
            # the cost is added here so both branches expose log['cost']
            log_ot['cost'] = nx.sum(G * M)
            return G, log_ot
        else:
            return emd(w1, w2, M, log=True, **kwargs)

    norm_delta = []

    # solve the barycenter
    for _ in range(numItermax):

        old_X = X

        # solve OT with template
        Ga, loga = solve_ot(Xa, X, a, w)
        Gb, logb = solve_ot(X, Xb, w, b)

        # Average of the two barycentric mappings of the support: each row of
        # the plans has mass 1 / r, hence the multiplication by r.
        X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r

        delta = nx.norm(X - old_X)
        if delta < stopThr:
            break
        # only the updates of non-converged iterations are recorded
        if log:
            norm_delta.append(delta)

    # NOTE: Ga and Gb were computed for the support used at the start of the
    # last iteration, while X is the support after the last update.
    if log:
        log_dic = {'delta_iter': norm_delta,
                   'ua': loga['u'],
                   'va': loga['v'],
                   'ub': logb['u'],
                   'vb': logb['v'],
                   'costa': loga['cost'],
                   'costb': logb['cost'],
                   }
        return Ga, Gb, X, log_dic

    return Ga, Gb, X