summaryrefslogtreecommitdiff
path: root/geom_matching/wasserstein/include/wasserstein_pure_geom.hpp
blob: 2a575990dd08ffb691b54ce72d6baf860f83e316 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#ifndef WASSERSTEIN_PURE_GEOM_HPP
#define WASSERSTEIN_PURE_GEOM_HPP

#define WASSERSTEIN_PURE_GEOM


#include <cmath>
#include <stdexcept>
#include <string>

#include "diagram_reader.h"
#include "auction_oracle_kdtree_pure_geom.h"
#include "auction_runner_gs.h"
#include "auction_runner_jac.h"

namespace hera
{
namespace ws
{

    // Short aliases for the dnn dynamic-dimension point types used by this header.
    template <class Real>
    using DynamicTraits = typename hera::ws::dnn::DynamicPointTraits<Real>;

    template <class Real>
    using DynamicPoint = typename hera::ws::dnn::DynamicPointTraits<Real>::PointType;

    template <class Real>
    using DynamicPointVector = typename hera::ws::dnn::DynamicPointVector<Real>;

    // Auction runners bound to the pure-geometry k-d tree oracle and dynamic point vectors.
    // wasserstein_cost uses AuctionRunnerGSR when params.max_bids_per_round == 1,
    // and AuctionRunnerJacR otherwise.
    template <class Real>
    using AuctionRunnerGSR = typename hera::ws::AuctionRunnerGS<Real, hera::ws::AuctionOracleKDTreePureGeom<Real>, hera::ws::dnn::DynamicPointVector<Real>>;

    template <class Real>
    using AuctionRunnerJacR = typename hera::ws::AuctionRunnerJac<Real, hera::ws::AuctionOracleKDTreePureGeom<Real>, hera::ws::dnn::DynamicPointVector<Real>>;


/**
 * @brief Computes the Wasserstein matching cost between two point clouds of equal size
 *        by running an auction with a k-d tree oracle.
 *
 * @param set_A   first point cloud; not modified (the function works on a copy).
 * @param set_B   second point cloud; must contain as many points as set_A.
 * @param params  auction parameters. params.dim is passed to the point traits;
 *                params.max_bids_per_round == 1 selects AuctionRunnerGSR,
 *                any other value selects AuctionRunnerJacR.
 * @return the value of the runner's get_wasserstein_cost(), or 0.0 when both clouds are empty.
 *         wasserstein_dist() turns it into a distance by taking the 1/wasserstein_power root.
 * @throws std::runtime_error if wasserstein_power < 1, delta < 0, initial_epsilon < 0,
 *         epsilon_common_ratio < 0, or the clouds have different sizes.
 */
// `inline` is required: this non-template function is defined in a header, and without it
// every translation unit including the header would emit its own definition (ODR violation).
inline double wasserstein_cost(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
{
    if (params.wasserstein_power < 1.0) {
        throw std::runtime_error("Bad q in Wasserstein " + std::to_string(params.wasserstein_power));
    }

    if (params.delta < 0.0) {
        throw std::runtime_error("Bad delta in Wasserstein " + std::to_string(params.delta));
    }

    if (params.initial_epsilon < 0.0) {
        throw std::runtime_error("Bad initial epsilon in Wasserstein " + std::to_string(params.initial_epsilon));
    }

    if (params.epsilon_common_ratio < 0.0) {
        throw std::runtime_error("Bad epsilon factor in Wasserstein " + std::to_string(params.epsilon_common_ratio));
    }

    if (set_A.size() != set_B.size()) {
        throw std::runtime_error("Different cardinalities of point clouds: " + std::to_string(set_A.size()) + " != " +  std::to_string(set_B.size()));
    }

    // Both clouds are empty (sizes are equal here): the empty matching costs nothing,
    // so there is no need to build an oracle and run an auction over zero points.
    if (set_A.size() == 0) {
        return 0.0;
    }

    DynamicTraits<double> traits(params.dim);

    // Work on copies so the caller's point ids are left untouched.
    DynamicPointVector<double> set_A_copy(set_A);
    DynamicPointVector<double> set_B_copy(set_B);

    // Set each point's id to its index in the vector.
    for(size_t i = 0; i < set_A.size(); ++i) {
        traits.id(set_A_copy[i]) = i;
        traits.id(set_B_copy[i]) = i;
    }

    if (params.max_bids_per_round == 1) {
        hera::ws::AuctionRunnerGSR<double> auction(set_A_copy, set_B_copy, params);
        auction.run_auction();
        return auction.get_wasserstein_cost();
    } else {
        hera::ws::AuctionRunnerJacR<double> auction(set_A_copy, set_B_copy, params);
        auction.run_auction();
        return auction.get_wasserstein_cost();
    }

}

/**
 * @brief Wasserstein distance between two point clouds of equal size.
 *
 * Returns wasserstein_cost(set_A, set_B, params) raised to 1 / params.wasserstein_power.
 *
 * @throws std::runtime_error under the same conditions as wasserstein_cost().
 */
// `inline` for the same header-definition (ODR) reason as wasserstein_cost.
inline double wasserstein_dist(const DynamicPointVector<double>& set_A, const DynamicPointVector<double>& set_B, const AuctionParams<double>& params)
{
    return std::pow(wasserstein_cost(set_A, set_B, params), 1.0 / params.wasserstein_power);
}

} // ws
} // hera


#endif