diff options
Diffstat (limited to 'geom_matching/wasserstein/include/dnn/parallel/utils.h')
-rw-r--r-- | geom_matching/wasserstein/include/dnn/parallel/utils.h | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/geom_matching/wasserstein/include/dnn/parallel/utils.h b/geom_matching/wasserstein/include/dnn/parallel/utils.h index ba73814..7104ec3 100644 --- a/geom_matching/wasserstein/include/dnn/parallel/utils.h +++ b/geom_matching/wasserstein/include/dnn/parallel/utils.h @@ -1,8 +1,12 @@ -#ifndef PARALLEL_UTILS_H -#define PARALLEL_UTILS_H +#ifndef HERA_WS_PARALLEL_UTILS_H +#define HERA_WS_PARALLEL_UTILS_H #include "../utils.h" +namespace hera +{ +namespace ws +{ namespace dnn { // Assumes rng is synchronized across ranks @@ -15,11 +19,13 @@ namespace dnn typedef decltype(data[0]) T; shuffle(world, data, rng, [](T& x, T& y) { std::swap(x,y); }); } -} +} // dnn +} // ws +} // hera template<class DataVector, class RNGType, class SwapFunctor> void -dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty) +hera::ws::dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty) { // This is not a perfect shuffle: it dishes out data in chunks of 1/size. // (It can be interpreted as generating a bistochastic matrix by taking the @@ -42,7 +48,7 @@ dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const Swa RNGType local_rng(seed); // Shuffle local data - dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); + hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); // Decide how much of our data goes to i-th processor std::vector<size_t> out_counts(size); @@ -50,7 +56,7 @@ dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const Swa boost::counting_iterator<int>(size)); for (size_t i = 0; i < size; ++i) { - dnn::random_shuffle(ranks.begin(), ranks.end(), rng); + hera::ws::dnn::random_shuffle(ranks.begin(), ranks.end(), rng); ++out_counts[ranks[rank]]; } @@ -87,7 +93,7 @@ dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const Swa for(const DataVector& vec : incoming) for (size_t i = 0; i < vec.size(); ++i) data.push_back(vec[i]); - dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); + hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap); // XXX: the final shuffle is irrelevant for our purposes. But it's also cheap. } |