summaryrefslogtreecommitdiff
path: root/ot/coot.py
blob: 66dd2c84a1e639dc29f0c9cb9848cb2159cc2420 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
# -*- coding: utf-8 -*-
"""
CO-Optimal Transport solver
"""

# Author: Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License

import warnings
from .lp import emd
from .utils import list_to_array
from .backend import get_backend
from .bregman import sinkhorn


def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None,
                         epsilon=0, alpha=0, M_samp=None, M_feat=None,
                         warmstart=None, nits_bcd=100, tol_bcd=1e-7, eval_bcd=1,
                         nits_ot=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn",
                         early_stopping_tol=1e-6, log=False, verbose=False):
    r"""Compute the CO-Optimal Transport between two matrices.

    Return the sample and feature transport plans between
    :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and
    :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`.

    The function solves the following CO-Optimal Transport (COOT) problem:

    .. math::
        \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}}
        &\quad \sum_{i,j,k,l}
        (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l}
        + \alpha_s \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\
        &+ \alpha_f \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l}
        + \varepsilon_s \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T)
        + \varepsilon_f \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T)

    Where :

    - :math:`\mathbf{X}`: Data matrix in the source space
    - :math:`\mathbf{Y}`: Data matrix in the target space
    - :math:`\mathbf{M^{(s)}}`: Additional sample matrix
    - :math:`\mathbf{M^{(f)}}`: Additional feature matrix
    - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space
    - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space
    - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space
    - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space

    The problem is solved by Block Coordinate Descent (BCD): the sample
    coupling and the feature coupling are updated alternately, each update
    being a standard (entropic or exact) optimal transport problem.

    .. note:: This function allows epsilon to be zero.
              In that case, the :any:`ot.lp.emd` solver of POT will be used.

    Parameters
    ----------
    X : (n_sample_x, n_feature_x) array-like, float
        First input matrix.
    Y : (n_sample_y, n_feature_y) array-like, float
        Second input matrix.
    wx_samp : (n_sample_x, ) array-like, float, optional (default = None)
        Histogram assigned on rows (samples) of matrix X.
        Uniform distribution by default.
    wx_feat : (n_feature_x, ) array-like, float, optional (default = None)
        Histogram assigned on columns (features) of matrix X.
        Uniform distribution by default.
    wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
        Histogram assigned on rows (samples) of matrix Y.
        Uniform distribution by default.
    wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
        Histogram assigned on columns (features) of matrix Y.
        Uniform distribution by default.
    epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
        Non-negative regularization parameters for entropic approximation of
        sample and feature couplings. Where a value is 0, the EMD solver is used
        instead of the Sinkhorn solver for the corresponding coupling.
        If epsilon is scalar, then the same epsilon is applied to
        both regularization of sample and feature couplings.
    alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
        Coefficient parameter of linear terms with respect to the sample and feature couplings.
        If alpha is scalar, then the same alpha is applied to both linear terms.
    M_samp : (n_sample_x, n_sample_y), float, optional (default = None)
        Sample matrix with respect to the linear term on sample coupling.
        Ignored when the corresponding alpha is 0.
    M_feat : (n_feature_x, n_feature_y), float, optional (default = None)
        Feature matrix with respect to the linear term on feature coupling.
        Ignored when the corresponding alpha is 0.
    warmstart : dictionary, optional (default = None)
        Contains 4 keys:
            - "duals_sample" and "duals_feature" whose values are
              tuples of 2 vectors, of sizes (n_sample_x, ), (n_sample_y, ) and
              (n_feature_x, ), (n_feature_y, ) respectively.
              Initialization of sample and feature dual vectors
              if using Sinkhorn algorithm. Zero vectors by default.

            - "pi_sample" and "pi_feature" whose values are matrices
              of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
              Initialization of sample and feature couplings.
              Product of the marginal distributions by default.
    nits_bcd : int, optional (default = 100)
        Number of Block Coordinate Descent (BCD) iterations to solve COOT.
    tol_bcd : float, optional (default = 1e-7)
        Tolerance of BCD scheme. If the L1-norm between the current and previous
        sample couplings is under this threshold, then stop BCD scheme.
        Only checked at the iterations where the cost is evaluated.
    eval_bcd : int, optional (default = 1)
        Period, in BCD iterations, at which the COOT cost is evaluated, starting
        from the first iteration. For example, if `eval_bcd = 8`, then the cost
        is calculated at iterations 1, 9, 17, etc...
    nits_ot : int, optional (default = 500)
        Number of iterations to solve each of the
        two optimal transport problems in each BCD iteration.
    tol_sinkhorn : float, optional (default = 1e-7)
        Tolerance of Sinkhorn algorithm to stop the Sinkhorn scheme for
        entropic optimal transport problem (if any) in each BCD iteration.
        Only triggered when Sinkhorn solver is used.
    method_sinkhorn : string, optional (default = "sinkhorn")
        Method used in POT's `ot.sinkhorn` solver.
        Only support "sinkhorn" and "sinkhorn_log".
    early_stopping_tol : float, optional (default = 1e-6)
        Tolerance for the early stopping. If the absolute difference between
        the last 2 recorded COOT distances is under this tolerance, then stop BCD scheme.
    log : bool, optional (default = False)
        If True then the cost and 4 dual vectors, including
        2 from sample and 2 from feature couplings, are recorded.
    verbose : bool, optional (default = False)
        If True then print the COOT cost each time it is evaluated.

    Returns
    -------
    pi_samp : (n_sample_x, n_sample_y) array-like, float
        Sample coupling matrix.
    pi_feat : (n_feature_x, n_feature_y) array-like, float
        Feature coupling matrix.
    log : dictionary, optional
        Returned if `log` is True. The keys are:
            duals_sample : tuple of 2 vectors of sizes (n_sample_x, ) and (n_sample_y, )
                Pair of dual vectors when solving OT problem w.r.t the sample coupling.
                With the Sinkhorn solver these are the logarithms of the
                scaling vectors "u" and "v" reported by the solver.
            duals_feature : tuple of 2 vectors of sizes (n_feature_x, ) and (n_feature_y, )
                Pair of dual vectors when solving OT problem w.r.t the feature coupling.
            distances : list, float
                List of COOT distances recorded at the evaluated iterations.

    Raises
    ------
    ValueError
        If `method_sinkhorn` is not supported, if `epsilon` or `alpha` is
        neither a scalar nor of length 2, or if `epsilon` contains a negative
        (or NaN) value.

    References
    ----------
    .. [49] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
        Advances in Neural Information Processing Systems, 33 (2020).
    """

    def split_pair(value, name):
        # Broadcast a scalar to a (sample, feature) pair or unpack a length-2 object.
        if isinstance(value, (float, int)):
            return value, value
        try:
            size = len(value)
        except TypeError:
            # Objects without a length (e.g. NumPy scalars or 0-d arrays that do
            # not subclass float/int) are scalars as well.
            return value, value
        if size != 2:
            raise ValueError(
                "{} must be either a scalar or an indexable object of length 2.".format(name))
        return value[0], value[1]

    # The two helpers below use ``nx``, which is bound further down, before
    # either of them is called.
    def compute_kl(p, q):
        # sum(p * log(p / q)), with 0 * log(0) taken as 0.
        return nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))

    def solve_ot(a, b, cost, eps, duals):
        # Solve one OT sub-problem; return the coupling and the pair of duals.
        if eps > 0:
            pi, ot_log = sinkhorn(a=a, b=b, M=cost, reg=eps, method=method_sinkhorn,
                                  numItermax=nits_ot, stopThr=tol_sinkhorn, log=True,
                                  warmstart=duals)
            # Store log-scalings: this is the format passed back as ``warmstart``.
            return pi, (nx.log(ot_log["u"]), nx.log(ot_log["v"]))
        pi, ot_log = emd(a=a, b=b, M=cost, numItermax=nits_ot, log=True)
        return pi, (ot_log["u"], ot_log["v"])

    # Validate the scalar settings first so that bad arguments fail fast.
    if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
        raise ValueError(
            "Method {} is not supported in CO-Optimal Transport.".format(method_sinkhorn))

    eps_samp, eps_feat = split_pair(epsilon, "Epsilon")
    alpha_samp, alpha_feat = split_pair(alpha, "Alpha")

    # Written as a negated ">=" so that NaN is rejected too. A negative value
    # used to skip the coupling updates silently.
    if not (eps_samp >= 0 and eps_feat >= 0):
        raise ValueError("Epsilon must be non-negative.")

    X, Y = list_to_array(X, Y)
    nx = get_backend(X, Y)

    # constant input variables: drop the linear terms that are switched off
    if M_samp is None or alpha_samp == 0:
        M_samp, alpha_samp = 0, 0
    if M_feat is None or alpha_feat == 0:
        M_feat, alpha_feat = 0, 0

    nx_samp, nx_feat = X.shape
    ny_samp, ny_feat = Y.shape

    # measures on rows and columns
    if wx_samp is None:
        wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp
    if wx_feat is None:
        wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat
    if wy_samp is None:
        wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp
    if wy_feat is None:
        wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat

    wxy_samp = wx_samp[:, None] * wy_samp[None, :]
    wxy_feat = wx_feat[:, None] * wy_feat[None, :]

    # pre-calculate the parts of the two OT costs that do not depend on the couplings
    XY_sqr = (X ** 2 @ wx_feat)[:, None] + (Y ** 2 @ wy_feat)[None, :] \
        + alpha_samp * M_samp
    XY_sqr_T = ((X.T) ** 2 @ wx_samp)[:, None] + ((Y.T) ** 2 @ wy_samp)[None, :] \
        + alpha_feat * M_feat

    # initialize coupling and dual vectors
    if warmstart is None:
        pi_samp, pi_feat = wxy_samp, wxy_feat  # shape nx_samp x ny_samp and nx_feat x ny_feat
        duals_samp = (nx.zeros(nx_samp, type_as=X),
                      nx.zeros(ny_samp, type_as=Y))  # shape nx_samp, ny_samp
        duals_feat = (nx.zeros(nx_feat, type_as=X),
                      nx.zeros(ny_feat, type_as=Y))  # shape nx_feat, ny_feat
    else:
        pi_samp, pi_feat = warmstart["pi_sample"], warmstart["pi_feature"]
        duals_samp, duals_feat = warmstart["duals_sample"], warmstart["duals_feature"]

    # The leading inf keeps the first early-stopping comparison from triggering;
    # it is stripped from the returned log.
    list_coot = [float("inf")]

    for idx in range(nits_bcd):
        pi_samp_prev = nx.copy(pi_samp)

        # update sample coupling
        ot_cost = XY_sqr - 2 * X @ pi_feat @ Y.T  # size nx_samp x ny_samp
        pi_samp, duals_samp = solve_ot(wx_samp, wy_samp, ot_cost, eps_samp, duals_samp)

        # update feature coupling
        ot_cost = XY_sqr_T - 2 * X.T @ pi_samp @ Y  # size nx_feat x ny_feat
        pi_feat, duals_feat = solve_ot(wx_feat, wy_feat, ot_cost, eps_feat, duals_feat)

        if idx % eval_bcd == 0:
            # update error
            err = nx.sum(nx.abs(pi_samp - pi_samp_prev))

            # COOT part: ``ot_cost`` is the feature cost here, so it already
            # contains the linear term on the feature coupling.
            coot = nx.sum(ot_cost * pi_feat)
            if alpha_samp != 0:
                coot = coot + alpha_samp * nx.sum(M_samp * pi_samp)
            # Entropic part
            if eps_samp != 0:
                coot = coot + eps_samp * compute_kl(pi_samp, wxy_samp)
            if eps_feat != 0:
                coot = coot + eps_feat * compute_kl(pi_feat, wxy_feat)
            list_coot.append(coot)

            # Print before the stopping test so the final cost is reported too.
            if verbose:
                print(
                    "CO-Optimal Transport cost at iteration {}: {}".format(idx + 1, coot))

            if err < tol_bcd or abs(list_coot[-2] - list_coot[-1]) < early_stopping_tol:
                break

    # sanity check
    if nx.sum(nx.isnan(pi_samp)) > 0 or nx.sum(nx.isnan(pi_feat)) > 0:
        warnings.warn("There is NaN in coupling.")

    if log:
        dict_log = {"duals_sample": duals_samp,
                    "duals_feature": duals_feat,
                    "distances": list_coot[1:]}

        return pi_samp, pi_feat, dict_log

    return pi_samp, pi_feat


def co_optimal_transport2(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None,
                          epsilon=0, alpha=0, M_samp=None, M_feat=None,
                          warmstart=None, log=False, verbose=False, early_stopping_tol=1e-6,
                          nits_bcd=100, tol_bcd=1e-7, eval_bcd=1,
                          nits_ot=500, tol_sinkhorn=1e-7,
                          method_sinkhorn="sinkhorn"):
    r"""Compute the CO-Optimal Transport distance between two measures.

    Returns the CO-Optimal Transport distance between
    :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and
    :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`.

    The couplings are obtained from :any:`co_optimal_transport`, which solves
    the following CO-Optimal Transport (COOT) problem:

    .. math::
        \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\min}_{\mathbf{P}, \mathbf{Q}}
        &\quad \sum_{i,j,k,l}
        (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l}
        + \alpha_1 \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\
        &+ \alpha_2 \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l}
        + \varepsilon_1 \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T)
        + \varepsilon_2 \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T)

    Where :

    - :math:`\mathbf{X}`: Data matrix in the source space
    - :math:`\mathbf{Y}`: Data matrix in the target space
    - :math:`\mathbf{M^{(s)}}`: Additional sample matrix
    - :math:`\mathbf{M^{(f)}}`: Additional feature matrix
    - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space
    - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space
    - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space
    - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space

    The returned value is the last cost recorded by the solver. Gradients with
    respect to the four histograms and to X and Y are attached to it through
    the backend's `set_gradients`.

    .. note:: This function allows epsilon to be zero.
              In that case, the :any:`ot.lp.emd` solver of POT will be used.

    Parameters
    ----------
    X : (n_sample_x, n_feature_x) array-like, float
        First input matrix.
    Y : (n_sample_y, n_feature_y) array-like, float
        Second input matrix.
    wx_samp : (n_sample_x, ) array-like, float, optional (default = None)
        Histogram assigned on rows (samples) of matrix X.
        Uniform distribution by default.
    wx_feat : (n_feature_x, ) array-like, float, optional (default = None)
        Histogram assigned on columns (features) of matrix X.
        Uniform distribution by default.
    wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
        Histogram assigned on rows (samples) of matrix Y.
        Uniform distribution by default.
    wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
        Histogram assigned on columns (features) of matrix Y.
        Uniform distribution by default.
    epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
        Regularization parameters for entropic approximation of sample and feature couplings.
        Allow the case where epsilon contains 0. In that case, the EMD solver is used instead of
        Sinkhorn solver. If epsilon is scalar, then the same epsilon is applied to
        both regularization of sample and feature couplings.
    alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
        Coefficient parameter of linear terms with respect to the sample and feature couplings.
        If alpha is scalar, then the same alpha is applied to both linear terms.
    M_samp : (n_sample_x, n_sample_y), float, optional (default = None)
        Sample matrix with respect to the linear term on sample coupling.
    M_feat : (n_feature_x, n_feature_y), float, optional (default = None)
        Feature matrix with respect to the linear term on feature coupling.
    warmstart : dictionary, optional (default = None)
        Forwarded to :any:`co_optimal_transport`. Contains 4 keys:
            - "duals_sample" and "duals_feature": pairs of dual vectors used
              to initialize the Sinkhorn algorithm.

            - "pi_sample" and "pi_feature": matrices
              of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y)
              used to initialize the sample and feature couplings.
    nits_bcd : int, optional (default = 100)
        Number of Block Coordinate Descent (BCD) iterations to solve COOT.
    tol_bcd : float, optional (default = 1e-7)
        Tolerance of BCD scheme. If the L1-norm between the current and previous
        sample couplings is under this threshold, then stop BCD scheme.
    eval_bcd : int, optional (default = 1)
        Period, in BCD iterations, at which the solver evaluates the COOT cost.
    nits_ot : int, optional (default = 500)
        Number of iterations to solve each of the
        two optimal transport problems in each BCD iteration.
    tol_sinkhorn : float, optional (default = 1e-7)
        Tolerance of Sinkhorn algorithm to stop the Sinkhorn scheme for
        entropic optimal transport problem (if any) in each BCD iteration.
        Only triggered when Sinkhorn solver is used.
    method_sinkhorn : string, optional (default = "sinkhorn")
        Method used in POT's `ot.sinkhorn` solver.
        Only support "sinkhorn" and "sinkhorn_log".
    early_stopping_tol : float, optional (default = 1e-6)
        Tolerance for the early stopping. If the absolute difference between
        the last 2 recorded COOT distances is under this tolerance, then stop BCD scheme.
    log : bool, optional (default = False)
        If True then the solver's log dictionary is returned as well.
    verbose : bool, optional (default = False)
        If True then the solver prints the COOT cost when it evaluates it.

    Returns
    -------
    float
        CO-Optimal Transport distance.
    dict
        Contains logged informations from :any:`co_optimal_transport` solver.
        Only returned if `log` parameter is True

    References
    ----------
    .. [49] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
        Advances in Neural Information Processing Systems, 33 (2020).
    """

    # The solver log is always requested: it carries the recorded costs and the
    # dual vectors needed below, whatever the caller's ``log`` flag is.
    pi_samp, pi_feat, dict_log = co_optimal_transport(
        X=X, Y=Y,
        wx_samp=wx_samp, wx_feat=wx_feat,
        wy_samp=wy_samp, wy_feat=wy_feat,
        epsilon=epsilon, alpha=alpha,
        M_samp=M_samp, M_feat=M_feat,
        warmstart=warmstart,
        nits_bcd=nits_bcd, tol_bcd=tol_bcd, eval_bcd=eval_bcd,
        nits_ot=nits_ot, tol_sinkhorn=tol_sinkhorn,
        method_sinkhorn=method_sinkhorn,
        early_stopping_tol=early_stopping_tol,
        log=True, verbose=verbose)

    X, Y = list_to_array(X, Y)
    nx = get_backend(X, Y)

    n_rows_x, n_cols_x = X.shape
    n_rows_y, n_cols_y = Y.shape

    def uniform(size, ref):
        # Uniform histogram of length ``size`` with the same type as ``ref``.
        return nx.ones(size, type_as=ref) / size

    # Histograms left as None were replaced by uniform ones inside the solver;
    # rebuild them here the same way for the gradient formulas.
    wx_samp = uniform(n_rows_x, X) if wx_samp is None else wx_samp
    wx_feat = uniform(n_cols_x, X) if wx_feat is None else wx_feat
    wy_samp = uniform(n_rows_y, Y) if wy_samp is None else wy_samp
    wy_feat = uniform(n_cols_y, Y) if wy_feat is None else wy_feat

    # Dual vectors serve as gradients with respect to the histograms.
    # NOTE(review): for epsilon > 0 the solver stores log(u), log(v) of the
    # Sinkhorn scalings as duals; confirm this is the intended scaling for the
    # weight gradients.
    vx_samp, vy_samp = dict_log["duals_sample"]
    vx_feat, vy_feat = dict_log["duals_feature"]

    # Gradients of the quadratic term with respect to the data matrices.
    mass_x = wx_samp[:, None] * wx_feat[None, :]
    mass_y = wy_samp[:, None] * wy_feat[None, :]
    gradX = 2 * X * mass_x - 2 * pi_samp @ Y @ pi_feat.T  # shape (nx_samp, nx_feat)
    gradY = 2 * Y * mass_y - 2 * pi_samp.T @ X @ pi_feat  # shape (ny_samp, ny_feat)

    inputs = (wx_samp, wx_feat, wy_samp, wy_feat, X, Y)
    grads = (vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY)

    # Last cost recorded by the solver, with the gradients above attached.
    coot = dict_log["distances"][-1]
    coot = nx.set_gradients(coot, inputs, grads)

    return (coot, dict_log) if log else coot