summaryrefslogtreecommitdiff
path: root/ot/weak.py
blob: f7d5b2339000ae1aa5e3784d9d0b0bb3e6bf3a9e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Weak optimal transport solvers
"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

from .backend import get_backend
from .optim import cg
import numpy as np

# Public names exported by this module (used by ``from ... import *``).
__all__ = ['weak_optimal_transport']


def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs):
    r"""Solves the weak optimal transport problem between two empirical distributions


    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \sum_i \mathbf{a}_i
        \left\| X_{a,i} - \frac{1}{\mathbf{a}_i} \sum_j \gamma_{ij} X_{b,j} \right\|^2

        s.t. \ \gamma \mathbf{1} = \mathbf{a}

             \gamma^T \mathbf{1} = \mathbf{b}

             \gamma \geq 0

    where :

    - :math:`X_a` :math:`X_b`  are the sample matrices.
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights


    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.

    Uses the conditional gradient algorithm to solve the problem proposed
    in :ref:`[39] <references-weak>`.

    Parameters
    ----------
    Xa : (ns,d) array-like, float
        Source samples
    Xb : (nt,d) array-like, float
        Target samples
    a : (ns,) array-like, float, optional
        Source histogram (uniform weight if None or empty)
    b : (nt,) array-like, float, optional
        Target histogram (uniform weight if None or empty)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    G0 : (ns,nt) array-like, float, optional
        Initial transport plan given to the solver. When None, the
        independent coupling :math:`\mathbf{a} \mathbf{b}^T` is used.
    **kwargs : dict, optional
        Extra keyword arguments forwarded unchanged to :any:`ot.optim.cg`.
        The options below are the ones usually passed this way.
    numItermax : int, optional
        Max number of iterations
    numItermaxEmd : int, optional
        Max number of iterations for emd
    stopThr : float, optional
        Stop threshold on the relative variation (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation (>0)


    Returns
    -------
    gamma: array-like, shape (ns, nt)
        Optimal transportation matrix for the given
        parameters
    log: dict, optional
        If input log is true, a dictionary containing the
        cost and dual variables and exit status


    .. _references-weak:
    References
    ----------
    .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
            Kantorovich duality for general transport costs and applications.
            Journal of Functional Analysis, 273(11), 3327-3405.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.optim.cg : General regularized OT
    """

    nx = get_backend(Xa, Xb)

    # The solver works on NumPy arrays; results are converted back at the end.
    Xa2 = nx.to_numpy(Xa)
    Xb2 = nx.to_numpy(Xb)

    def _weights(w, n):
        # None and empty weights both mean "uniform", as the docstring states.
        if w is None or len(w) == 0:
            return np.ones(n) / n
        return nx.to_numpy(w)

    a2 = _weights(a, Xa.shape[0])
    b2 = _weights(b, Xb.shape[0])

    # Default initialization: independent coupling a b^T.
    if G0 is None:
        T0 = a2[:, None] * b2[None, :]
    else:
        T0 = nx.to_numpy(G0)

    def _residual(T):
        # Xa minus the barycentric projection of Xb under plan T.
        return Xa2 - np.dot(T, Xb2) / a2[:, None]

    # weak OT loss: a-weighted sum of squared residual norms
    def f(T):
        return np.dot(a2, np.sum(_residual(T)**2, 1))

    # weak OT gradient (the 1/a_i of the projection cancels the a_i weight)
    def df(T):
        return -2 * np.dot(_residual(T), Xb2.T)

    # Solve with conditional gradient. 0 and 1 are cg's third and fourth
    # positional arguments (presumably the linear cost and the weight of f;
    # confirm against ot.optim.cg).
    out = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs)

    if log:
        # Keep the returned dict under its own name instead of rebinding `log`.
        res, log_dict = out
        log_dict['u'] = nx.from_numpy(log_dict['u'], type_as=Xa)
        log_dict['v'] = nx.from_numpy(log_dict['v'], type_as=Xb)
        return nx.from_numpy(res, type_as=Xa), log_dict

    return nx.from_numpy(out, type_as=Xa)