summaryrefslogtreecommitdiff
path: root/wasserstein/include/dnn
diff options
context:
space:
mode:
Diffstat (limited to 'wasserstein/include/dnn')
-rw-r--r--wasserstein/include/dnn/geometry/euclidean-dynamic.h270
-rw-r--r--wasserstein/include/dnn/geometry/euclidean-fixed.h196
-rw-r--r--wasserstein/include/dnn/local/kd-tree.h97
-rw-r--r--wasserstein/include/dnn/local/kd-tree.hpp330
-rw-r--r--wasserstein/include/dnn/local/search-functors.h95
-rw-r--r--wasserstein/include/dnn/parallel/tbb.h237
-rw-r--r--wasserstein/include/dnn/parallel/utils.h100
-rw-r--r--wasserstein/include/dnn/utils.h47
8 files changed, 1372 insertions, 0 deletions
diff --git a/wasserstein/include/dnn/geometry/euclidean-dynamic.h b/wasserstein/include/dnn/geometry/euclidean-dynamic.h
new file mode 100644
index 0000000..b003906
--- /dev/null
+++ b/wasserstein/include/dnn/geometry/euclidean-dynamic.h
@@ -0,0 +1,270 @@
+#ifndef DNN_GEOMETRY_EUCLIDEAN_DYNAMIC_H
+#define DNN_GEOMETRY_EUCLIDEAN_DYNAMIC_H
+
+#include <vector>
+#include <algorithm>
+#include <boost/iterator/iterator_facade.hpp>
+#include <boost/serialization/access.hpp>
+#include <boost/serialization/vector.hpp>
+#include <cmath>
+
+#include "hera_infinity.h"
+
+namespace hera
+{
+namespace ws
+{
+namespace dnn
+{
+
+template<class Real_>
+class DynamicPointVector
+{
+ public:
+ using Real = Real_;
+ struct PointType
+ {
+ void* p;
+
+ Real& operator[](const int i)
+ {
+ return (static_cast<Real*>(p))[i];
+ }
+
+ const Real& operator[](const int i) const
+ {
+ return (static_cast<Real*>(p))[i];
+ }
+
+ };
+ struct iterator;
+ typedef iterator const_iterator;
+
+ public:
+ DynamicPointVector(size_t point_capacity = 0):
+ point_capacity_(point_capacity) {}
+
+
+ PointType operator[](size_t i) const { return {(void*) &storage_[i*point_capacity_]}; }
+ inline void push_back(PointType p);
+
+ inline iterator begin();
+ inline iterator end();
+ inline const_iterator begin() const;
+ inline const_iterator end() const;
+
+ size_t size() const { return storage_.size() / point_capacity_; }
+
+ void clear() { storage_.clear(); }
+ void swap(DynamicPointVector& other) { storage_.swap(other.storage_); std::swap(point_capacity_, other.point_capacity_); }
+ void reserve(size_t sz) { storage_.reserve(sz * point_capacity_); }
+ void resize(size_t sz) { storage_.resize(sz * point_capacity_); }
+
+ private:
+ size_t point_capacity_;
+ std::vector<char> storage_;
+
+ private:
+ friend class boost::serialization::access;
+
+ template<class Archive>
+ void serialize(Archive& ar, const unsigned int version) { ar & point_capacity_ & storage_; }
+};
+
+template<typename Real>
+struct DynamicPointTraits
+{
+ typedef DynamicPointVector<Real> PointContainer;
+ typedef typename PointContainer::PointType PointType;
+ struct PointHandle
+ {
+ void* p;
+ bool operator==(const PointHandle& other) const { return p == other.p; }
+ bool operator!=(const PointHandle& other) const { return !(*this == other); }
+ bool operator<(const PointHandle& other) const { return p < other.p; }
+ bool operator>(const PointHandle& other) const { return p > other.p; }
+ };
+
+ typedef Real Coordinate;
+ typedef Real DistanceType;
+
+ DynamicPointTraits(unsigned dim = 0):
+ dim_(dim) {}
+
+ DistanceType distance(PointType p1, PointType p2) const
+ {
+ Real result = 0.0;
+ if (hera::is_infinity(internal_p)) {
+ // max norm
+ for (unsigned i = 0; i < dimension(); ++i)
+ result = std::max(result, fabs(coordinate(p1,i) - coordinate(p2,i)));
+ } else if (internal_p == Real(1.0)) {
+ // l1-norm
+ for (unsigned i = 0; i < dimension(); ++i)
+ result += fabs(coordinate(p1,i) - coordinate(p2,i));
+ } else if (internal_p == Real(2.0)) {
+ result = sqrt(sq_distance(p1,p2));
+ } else {
+ assert(internal_p > 1.0);
+ for (unsigned i = 0; i < dimension(); ++i)
+ result += std::pow(fabs(coordinate(p1,i) - coordinate(p2,i)), internal_p);
+ result = std::pow(result, Real(1.0) / internal_p);
+ }
+ return result;
+ }
+ DistanceType distance(PointHandle p1, PointHandle p2) const { return distance(PointType({p1.p}), PointType({p2.p})); }
+ DistanceType sq_distance(PointType p1, PointType p2) const { Real res = 0; for (unsigned i = 0; i < dimension(); ++i) { Real c1 = coordinate(p1,i), c2 = coordinate(p2,i); res += (c1 - c2)*(c1 - c2); } return res; }
+ DistanceType sq_distance(PointHandle p1, PointHandle p2) const { return sq_distance(PointType({p1.p}), PointType({p2.p})); }
+ unsigned dimension() const { return dim_; }
+ Real& coordinate(PointType p, unsigned i) const { return ((Real*) p.p)[i]; }
+ Real& coordinate(PointHandle h, unsigned i) const { return ((Real*) h.p)[i]; }
+
+ // it's non-standard to return a reference, but we can rely on it for code that assumes this particular point type
+ size_t& id(PointType p) const { return *((size_t*) ((Real*) p.p + dimension())); }
+ size_t& id(PointHandle h) const { return *((size_t*) ((Real*) h.p + dimension())); }
+ PointHandle handle(PointType p) const { return {p.p}; }
+ PointType point(PointHandle h) const { return {h.p}; }
+
+ void swap(PointType p1, PointType p2) const { std::swap_ranges((char*) p1.p, ((char*) p1.p) + capacity(), (char*) p2.p); }
+ bool cmp(PointType p1, PointType p2) const { return std::lexicographical_compare((Real*) p1.p, ((Real*) p1.p) + dimension(), (Real*) p2.p, ((Real*) p2.p) + dimension()); }
+ bool eq(PointType p1, PointType p2) const { return std::equal((Real*) p1.p, ((Real*) p1.p) + dimension(), (Real*) p2.p); }
+
+ // non-standard, and possibly a weird name
+ size_t capacity() const { return sizeof(Real)*dimension() + sizeof(size_t); }
+
+ PointContainer container(size_t n = 0) const { PointContainer c(capacity()); c.resize(n); return c; }
+ PointContainer container(size_t n, const PointType& p) const;
+
+ typename PointContainer::iterator
+ iterator(PointContainer& c, PointHandle ph) const;
+ typename PointContainer::const_iterator
+ iterator(const PointContainer& c, PointHandle ph) const;
+
+ Real internal_p;
+
+ private:
+ unsigned dim_;
+
+ private:
+ friend class boost::serialization::access;
+
+ template<class Archive>
+ void serialize(Archive& ar, const unsigned int version) { ar & dim_; }
+};
+
+} // dnn
+
+template<class Real>
+struct dnn::DynamicPointVector<Real>::iterator:
+ public boost::iterator_facade<iterator,
+ PointType,
+ std::random_access_iterator_tag,
+ PointType,
+ std::ptrdiff_t>
+{
+ typedef boost::iterator_facade<iterator,
+ PointType,
+ std::random_access_iterator_tag,
+ PointType,
+ std::ptrdiff_t> Parent;
+
+
+ public:
+ typedef typename Parent::value_type value_type;
+ typedef typename Parent::difference_type difference_type;
+ typedef typename Parent::reference reference;
+
+ iterator(size_t point_capacity = 0):
+ point_capacity_(point_capacity) {}
+
+ iterator(void* p, size_t point_capacity):
+ p_(p), point_capacity_(point_capacity) {}
+
+ private:
+ void increment() { p_ = ((char*) p_) + point_capacity_; }
+ void decrement() { p_ = ((char*) p_) - point_capacity_; }
+ void advance(difference_type n) { p_ = ((char*) p_) + n*point_capacity_; }
+ difference_type
+ distance_to(iterator other) const { return (((char*) other.p_) - ((char*) p_))/(int) point_capacity_; }
+ bool equal(const iterator& other) const { return p_ == other.p_; }
+ reference dereference() const { return {p_}; }
+
+ friend class ::boost::iterator_core_access;
+
+ private:
+ void* p_;
+ size_t point_capacity_;
+};
+
+template<class Real>
+void dnn::DynamicPointVector<Real>::push_back(PointType p)
+{
+ if (storage_.capacity() < storage_.size() + point_capacity_)
+ storage_.reserve(1.5*storage_.capacity());
+
+ storage_.resize(storage_.size() + point_capacity_);
+
+ std::copy((char*) p.p, (char*) p.p + point_capacity_, storage_.end() - point_capacity_);
+}
+
+template<class Real>
+typename dnn::DynamicPointVector<Real>::iterator dnn::DynamicPointVector<Real>::begin() { return iterator((void*) &*storage_.begin(), point_capacity_); }
+
+template<class Real>
+typename dnn::DynamicPointVector<Real>::iterator dnn::DynamicPointVector<Real>::end() { return iterator((void*) &*storage_.end(), point_capacity_); }
+
+template<class Real>
+typename dnn::DynamicPointVector<Real>::const_iterator dnn::DynamicPointVector<Real>::begin() const { return const_iterator((void*) &*storage_.begin(), point_capacity_); }
+
+template<class Real>
+typename dnn::DynamicPointVector<Real>::const_iterator dnn::DynamicPointVector<Real>::end() const { return const_iterator((void*) &*storage_.end(), point_capacity_); }
+
+template<typename R>
+typename dnn::DynamicPointTraits<R>::PointContainer
+dnn::DynamicPointTraits<R>::container(size_t n, const PointType& p) const
+{
+ PointContainer c = container(n);
+ for (auto x : c)
+ std::copy((char*) p.p, (char*) p.p + capacity(), (char*) x.p);
+ return c;
+}
+
+template<typename R>
+typename dnn::DynamicPointTraits<R>::PointContainer::iterator
+dnn::DynamicPointTraits<R>::iterator(PointContainer& c, PointHandle ph) const
+{ return typename PointContainer::iterator(ph.p, capacity()); }
+
+template<typename R>
+typename dnn::DynamicPointTraits<R>::PointContainer::const_iterator
+dnn::DynamicPointTraits<R>::iterator(const PointContainer& c, PointHandle ph) const
+{ return typename PointContainer::const_iterator(ph.p, capacity()); }
+
+} // ws
+} // hera
+
+namespace std {
+ template<>
+ struct hash<typename hera::ws::dnn::DynamicPointTraits<double>::PointHandle>
+ {
+ using PointHandle = typename hera::ws::dnn::DynamicPointTraits<double>::PointHandle;
+ size_t operator()(const PointHandle& ph) const
+ {
+ return std::hash<void*>()(ph.p);
+ }
+ };
+
+ template<>
+ struct hash<typename hera::ws::dnn::DynamicPointTraits<float>::PointHandle>
+ {
+ using PointHandle = typename hera::ws::dnn::DynamicPointTraits<float>::PointHandle;
+ size_t operator()(const PointHandle& ph) const
+ {
+ return std::hash<void*>()(ph.p);
+ }
+ };
+
+
+} // std
+
+
+#endif
diff --git a/wasserstein/include/dnn/geometry/euclidean-fixed.h b/wasserstein/include/dnn/geometry/euclidean-fixed.h
new file mode 100644
index 0000000..3e38baf
--- /dev/null
+++ b/wasserstein/include/dnn/geometry/euclidean-fixed.h
@@ -0,0 +1,196 @@
+#ifndef HERA_WS_DNN_GEOMETRY_EUCLIDEAN_FIXED_H
+#define HERA_WS_DNN_GEOMETRY_EUCLIDEAN_FIXED_H
+
+#include <boost/operators.hpp>
+#include <boost/array.hpp>
+#include <boost/range/value_type.hpp>
+#include <boost/serialization/access.hpp>
+#include <boost/serialization/base_object.hpp>
+
+//#include <iostream>
+#include <fstream>
+#include <string>
+#include <sstream>
+#include <cmath>
+
+#include "../parallel/tbb.h" // for dnn::vector<...>
+
+namespace hera
+{
+namespace ws
+{
+namespace dnn
+{
+ // TODO: wrap in another namespace (e.g., euclidean)
+
+ template<size_t D, typename Real = double>
+ struct Point:
+ boost::addable< Point<D,Real>,
+ boost::subtractable< Point<D,Real>,
+ boost::dividable2< Point<D, Real>, Real,
+ boost::multipliable2< Point<D, Real>, Real > > > >,
+ public boost::array<Real, D>
+ {
+ public:
+ typedef Real Coordinate;
+ typedef Real DistanceType;
+
+
+ public:
+ Point(size_t id = 0): id_(id) {}
+ template<size_t DD>
+ Point(const Point<DD,Real>& p, size_t id = 0):
+ id_(id) { *this = p; }
+
+ static size_t dimension() { return D; }
+
+ // Assign a point of different dimension
+ template<size_t DD>
+ Point& operator=(const Point<DD,Real>& p) { for (size_t i = 0; i < (D < DD ? D : DD); ++i) (*this)[i] = p[i]; if (DD < D) for (size_t i = DD; i < D; ++i) (*this)[i] = 0; return *this; }
+
+ Point& operator+=(const Point& p) { for (size_t i = 0; i < D; ++i) (*this)[i] += p[i]; return *this; }
+ Point& operator-=(const Point& p) { for (size_t i = 0; i < D; ++i) (*this)[i] -= p[i]; return *this; }
+ Point& operator/=(Real r) { for (size_t i = 0; i < D; ++i) (*this)[i] /= r; return *this; }
+ Point& operator*=(Real r) { for (size_t i = 0; i < D; ++i) (*this)[i] *= r; return *this; }
+
+ Real norm2() const { Real n = 0; for (size_t i = 0; i < D; ++i) n += (*this)[i] * (*this)[i]; return n; }
+ Real max_norm() const
+ {
+ Real res = std::fabs((*this)[0]);
+ for (size_t i = 1; i < D; ++i)
+ if (std::fabs((*this)[i]) > res)
+ res = std::fabs((*this)[i]);
+ return res;
+ }
+
+ Real l1_norm() const
+ {
+ Real res = std::fabs((*this)[0]);
+ for (size_t i = 1; i < D; ++i)
+ res += std::fabs((*this)[i]);
+ return res;
+ }
+
+ Real lp_norm(const Real p) const
+ {
+ assert( !std::isinf(p) );
+ if ( p == 1.0 )
+ return l1_norm();
+ Real res = std::pow(std::fabs((*this)[0]), p);
+ for (size_t i = 1; i < D; ++i)
+ res += std::pow(std::fabs((*this)[i]), p);
+ return std::pow(res, 1.0 / p);
+ }
+
+ // quick and dirty for now; make generic later
+ //DistanceType distance(const Point& other) const { return sqrt(sq_distance(other)); }
+ //DistanceType sq_distance(const Point& other) const { return (other - *this).norm2(); }
+
+ DistanceType distance(const Point& other) const { return (other - *this).max_norm(); }
+ DistanceType p_distance(const Point& other, const double p) const { return (other - *this).lp_norm(p); }
+
+ size_t id() const { return id_; }
+ size_t& id() { return id_; }
+
+ private:
+ friend class boost::serialization::access;
+
+ template<class Archive>
+ void serialize(Archive& ar, const unsigned int version) { ar & boost::serialization::base_object< boost::array<Real,D> >(*this) & id_; }
+
+ private:
+ size_t id_;
+ };
+
+ template<size_t D, typename Real>
+ std::ostream&
+ operator<<(std::ostream& out, const Point<D,Real>& p)
+ { out << p[0]; for (size_t i = 1; i < D; ++i) out << " " << p[i]; return out; }
+
+
+ template<class Point>
+ struct PointTraits; // intentionally undefined; should be specialized for each type
+
+
+ template<size_t D, typename Real>
+ struct PointTraits< Point<D, Real> > // specialization for dnn::Point
+ {
+ typedef Point<D,Real> PointType;
+ typedef const PointType* PointHandle;
+ typedef std::vector<PointType> PointContainer;
+
+ typedef typename PointType::Coordinate Coordinate;
+ typedef typename PointType::DistanceType DistanceType;
+
+
+ static DistanceType
+ distance(const PointType& p1, const PointType& p2) { if (hera::is_infinity(internal_p)) return p1.distance(p2); else return p1.p_distance(p2, internal_p); }
+
+ static DistanceType
+ distance(PointHandle p1, PointHandle p2) { return distance(*p1,*p2); }
+
+ static size_t dimension() { return D; }
+ static Real coordinate(const PointType& p, size_t i) { return p[i]; }
+ static Real& coordinate(PointType& p, size_t i) { return p[i]; }
+ static Real coordinate(PointHandle p, size_t i) { return coordinate(*p,i); }
+
+ static size_t id(const PointType& p) { return p.id(); }
+ static size_t& id(PointType& p) { return p.id(); }
+ static size_t id(PointHandle p) { return id(*p); }
+
+ static PointHandle
+ handle(const PointType& p) { return &p; }
+ static const PointType&
+ point(PointHandle ph) { return *ph; }
+
+ void swap(PointType& p1, PointType& p2) const { return std::swap(p1, p2); }
+
+ static PointContainer
+ container(size_t n = 0, const PointType& p = PointType()) { return PointContainer(n, p); }
+ static typename PointContainer::iterator
+ iterator(PointContainer& c, PointHandle ph) { return c.begin() + (ph - &c[0]); }
+ static typename PointContainer::const_iterator
+ iterator(const PointContainer& c, PointHandle ph) { return c.begin() + (ph - &c[0]); }
+
+ // Internal_p determines which norm will be used in Wasserstein metric (not to
+ // be confused with wassersteinPower parameter:
+ // we raise \| p - q \|_{internal_p} to wassersteinPower.
+ static Real internal_p;
+
+ private:
+
+ friend class boost::serialization::access;
+
+ template<class Archive>
+ void serialize(Archive& ar, const unsigned int version) {}
+
+ };
+
+ template<size_t D, typename Real>
+ Real PointTraits< Point<D, Real> >::internal_p = hera::get_infinity<Real>();
+
+
+ template<class PointContainer>
+ void read_points(const std::string& filename, PointContainer& points)
+ {
+ typedef typename boost::range_value<PointContainer>::type Point;
+ typedef typename PointTraits<Point>::Coordinate Coordinate;
+
+ std::ifstream in(filename.c_str());
+ std::string line;
+ while(std::getline(in, line))
+ {
+ if (line[0] == '#') continue; // comment line in the file
+ std::stringstream linestream(line);
+ Coordinate x;
+ points.push_back(Point());
+ size_t i = 0;
+ while (linestream >> x)
+ points.back()[i++] = x;
+ }
+ }
+} // dnn
+} // ws
+} // hera
+
+#endif
diff --git a/wasserstein/include/dnn/local/kd-tree.h b/wasserstein/include/dnn/local/kd-tree.h
new file mode 100644
index 0000000..8e52a5c
--- /dev/null
+++ b/wasserstein/include/dnn/local/kd-tree.h
@@ -0,0 +1,97 @@
+#ifndef HERA_WS_DNN_LOCAL_KD_TREE_H
+#define HERA_WS_DNN_LOCAL_KD_TREE_H
+
+#include "../utils.h"
+#include "search-functors.h"
+
+#include <unordered_map>
+
+#include <boost/tuple/tuple.hpp>
+#include <boost/shared_ptr.hpp>
+#include <boost/range/value_type.hpp>
+
+#include <boost/static_assert.hpp>
+#include <boost/type_traits.hpp>
+
+namespace hera
+{
+namespace ws
+{
+namespace dnn
+{
+ // Weighted KDTree
+ // Traits_ provides Coordinate, DistanceType, PointType, dimension(), distance(p1,p2), coordinate(p,i)
+ template< class Traits_ >
+ class KDTree
+ {
+ public:
+ typedef Traits_ Traits;
+ typedef dnn::HandleDistance<KDTree> HandleDistance;
+
+ typedef typename Traits::PointType Point;
+ typedef typename Traits::PointHandle PointHandle;
+ typedef typename Traits::Coordinate Coordinate;
+ typedef typename Traits::DistanceType DistanceType;
+ typedef std::vector<PointHandle> HandleContainer;
+ typedef std::vector<HandleDistance> HDContainer; // TODO: use tbb::scalable_allocator
+ typedef HDContainer Result;
+ typedef std::vector<DistanceType> DistanceContainer;
+ typedef std::unordered_map<PointHandle, size_t> HandleMap;
+
+ BOOST_STATIC_ASSERT_MSG(has_coordinates<Traits, PointHandle, int>::value, "KDTree requires coordinates");
+
+ public:
+ KDTree(const Traits& traits):
+ traits_(traits) {}
+
+ KDTree(const Traits& traits, HandleContainer&& handles, double _wassersteinPower = 1.0);
+
+ template<class Range>
+ KDTree(const Traits& traits, const Range& range, double _wassersteinPower = 1.0);
+
+ template<class Range>
+ void init(const Range& range);
+
+ DistanceType weight(PointHandle p) { return weights_[indices_[p]]; }
+ void change_weight(PointHandle p, DistanceType w);
+ void adjust_weights(DistanceType delta); // subtract delta from all weights
+
+ HandleDistance find(PointHandle q) const;
+ Result findR(PointHandle q, DistanceType r) const; // all neighbors within r
+ Result findK(PointHandle q, size_t k) const; // k nearest neighbors
+
+ HandleDistance find(const Point& q) const { return find(traits().handle(q)); }
+ Result findR(const Point& q, DistanceType r) const { return findR(traits().handle(q), r); }
+ Result findK(const Point& q, size_t k) const { return findK(traits().handle(q), k); }
+
+ template<class ResultsFunctor>
+ void search(PointHandle q, ResultsFunctor& rf) const;
+
+ const Traits& traits() const { return traits_; }
+
+ void printWeights(void);
+
+ private:
+ void init();
+
+ typedef typename HandleContainer::iterator HCIterator;
+ typedef std::tuple<HCIterator, HCIterator, size_t> KDTreeNode;
+
+ struct CoordinateComparison;
+ struct OrderTree;
+
+ private:
+ Traits traits_;
+ HandleContainer tree_;
+ DistanceContainer weights_; // point weight
+ DistanceContainer subtree_weights_; // min weight in the subtree
+ HandleMap indices_;
+ double wassersteinPower;
+ };
+} // dnn
+} // ws
+} // hera
+
+#include "kd-tree.hpp"
+
+#endif
diff --git a/wasserstein/include/dnn/local/kd-tree.hpp b/wasserstein/include/dnn/local/kd-tree.hpp
new file mode 100644
index 0000000..bdeef45
--- /dev/null
+++ b/wasserstein/include/dnn/local/kd-tree.hpp
@@ -0,0 +1,330 @@
+#include <boost/range/counting_range.hpp>
+#include <boost/range/algorithm_ext/push_back.hpp>
+#include <boost/range.hpp>
+
+#include <queue>
+#include <stack>
+
+#include "../parallel/tbb.h"
+#include "def_debug_ws.h"
+
+template<class T>
+hera::ws::dnn::KDTree<T>::
+KDTree(const Traits& traits, HandleContainer&& handles, double _wassersteinPower):
+ traits_(traits), tree_(std::move(handles)), wassersteinPower(_wassersteinPower)
+{ assert(wassersteinPower >= 1.0); init(); }
+
+template<class T>
+template<class Range>
+hera::ws::dnn::KDTree<T>::
+KDTree(const Traits& traits, const Range& range, double _wassersteinPower):
+ traits_(traits), wassersteinPower(_wassersteinPower)
+{
+ assert( wassersteinPower >= 1.0);
+ init(range);
+}
+
+template<class T>
+template<class Range>
+void
+hera::ws::dnn::KDTree<T>::
+init(const Range& range)
+{
+ size_t sz = std::distance(std::begin(range), std::end(range));
+ tree_.reserve(sz);
+ weights_.resize(sz, 0);
+ subtree_weights_.resize(sz, 0);
+ for (PointHandle h : range)
+ tree_.push_back(h);
+ init();
+}
+
+template<class T>
+void
+hera::ws::dnn::KDTree<T>::
+init()
+{
+ if (tree_.empty())
+ return;
+
+#if defined(TBB)
+ task_group g;
+ g.run(OrderTree(tree_.begin(), tree_.end(), 0, traits()));
+ g.wait();
+#else
+ OrderTree(tree_.begin(), tree_.end(), 0, traits()).serial();
+#endif
+
+ for (size_t i = 0; i < tree_.size(); ++i)
+ indices_[tree_[i]] = i;
+}
+
+template<class T>
+struct
+hera::ws::dnn::KDTree<T>::OrderTree
+{
+ OrderTree(HCIterator b_, HCIterator e_, size_t i_, const Traits& traits_):
+ b(b_), e(e_), i(i_), traits(traits_) {}
+
+ void operator()() const
+ {
+ if (e - b < 1000)
+ {
+ serial();
+ return;
+ }
+
+ HCIterator m = b + (e - b)/2;
+ CoordinateComparison cmp(i, traits);
+ std::nth_element(b,m,e, cmp);
+ size_t next_i = (i + 1) % traits.dimension();
+
+ task_group g;
+ if (b < m - 1) g.run(OrderTree(b, m, next_i, traits));
+ if (e > m + 2) g.run(OrderTree(m+1, e, next_i, traits));
+ g.wait();
+ }
+
+ void serial() const
+ {
+ std::queue<KDTreeNode> q;
+ q.push(KDTreeNode(b,e,i));
+ while (!q.empty())
+ {
+ HCIterator b, e; size_t i;
+ std::tie(b,e,i) = q.front();
+ q.pop();
+ HCIterator m = b + (e - b)/2;
+
+ CoordinateComparison cmp(i, traits);
+ std::nth_element(b,m,e, cmp);
+ size_t next_i = (i + 1) % traits.dimension();
+
+ // Replace with a size condition instead?
+ if (m - b > 1) q.push(KDTreeNode(b, m, next_i));
+ if (e - m > 2) q.push(KDTreeNode(m+1, e, next_i));
+ }
+ }
+
+ HCIterator b, e;
+ size_t i;
+ const Traits& traits;
+};
+
+template<class T>
+template<class ResultsFunctor>
+void
+hera::ws::dnn::KDTree<T>::
+search(PointHandle q, ResultsFunctor& rf) const
+{
+ typedef typename HandleContainer::const_iterator HCIterator;
+ typedef std::tuple<HCIterator, HCIterator, size_t> KDTreeNode;
+
+ if (tree_.empty())
+ return;
+
+ DistanceType D = std::numeric_limits<DistanceType>::max();
+
+ // TODO: use tbb::scalable_allocator for the queue
+ std::queue<KDTreeNode> nodes;
+
+ nodes.push(KDTreeNode(tree_.begin(), tree_.end(), 0));
+
+ while (!nodes.empty())
+ {
+ HCIterator b, e; size_t i;
+ std::tie(b,e,i) = nodes.front();
+ nodes.pop();
+
+ CoordinateComparison cmp(i, traits());
+ i = (i + 1) % traits().dimension();
+
+ HCIterator m = b + (e - b)/2;
+
+ DistanceType dist = (wassersteinPower == 1.0) ? traits().distance(q, *m) + weights_[m - tree_.begin()] : std::pow(traits().distance(q, *m), wassersteinPower) + weights_[m - tree_.begin()];
+
+
+ D = rf(*m, dist);
+
+ // we are really searching w.r.t L_\infty ball; could prune better with an L_2 ball
+ Coordinate diff = cmp.diff(q, *m); // diff returns signed distance
+
+ DistanceType diffToWasserPower = (wassersteinPower == 1.0) ? diff : ((diff > 0 ? 1.0 : -1.0) * std::pow(fabs(diff), wassersteinPower));
+
+ size_t lm = m + 1 + (e - (m+1))/2 - tree_.begin();
+ if (e > m + 1 && diffToWasserPower - subtree_weights_[lm] >= -D) {
+ nodes.push(KDTreeNode(m+1, e, i));
+ }
+
+ size_t rm = b + (m - b) / 2 - tree_.begin();
+ if (b < m && diffToWasserPower + subtree_weights_[rm] <= D) {
+ nodes.push(KDTreeNode(b, m, i));
+ }
+ }
+}
+
+template<class T>
+void
+hera::ws::dnn::KDTree<T>::
+adjust_weights(DistanceType delta)
+{
+ for(auto& w : weights_)
+ w -= delta;
+
+ for(auto& sw : subtree_weights_)
+ sw -= delta;
+}
+
+
+template<class T>
+void
+hera::ws::dnn::KDTree<T>::
+change_weight(PointHandle p, DistanceType w)
+{
+ size_t idx = indices_[p];
+
+ if ( weights_[idx] == w ) {
+ return;
+ }
+
+ bool weight_increases = ( weights_[idx] < w );
+ weights_[idx] = w;
+
+ typedef std::tuple<HCIterator, HCIterator> KDTreeNode;
+
+ // find the path down the tree to this node
+ // not an ideal strategy, but // it's not clear how to move up from the node in general
+ std::stack<KDTreeNode> s;
+ s.push(KDTreeNode(tree_.begin(),tree_.end()));
+
+ do
+ {
+ HCIterator b,e;
+ std::tie(b,e) = s.top();
+
+ size_t im = b + (e - b)/2 - tree_.begin();
+
+ if (idx == im)
+ break;
+ else if (idx < im)
+ s.push(KDTreeNode(b, tree_.begin() + im));
+ else // idx > im
+ s.push(KDTreeNode(tree_.begin() + im + 1, e));
+ } while(1);
+
+ // update subtree_weights_ on the path to the root
+ DistanceType min_w = w;
+ while (!s.empty())
+ {
+ HCIterator b,e;
+ std::tie(b,e) = s.top();
+ HCIterator m = b + (e - b)/2;
+ size_t im = m - tree_.begin();
+ s.pop();
+
+
+ // left and right children
+ if (b < m)
+ {
+ size_t lm = b + (m - b)/2 - tree_.begin();
+ if (subtree_weights_[lm] < min_w)
+ min_w = subtree_weights_[lm];
+ }
+
+ if (e > m + 1)
+ {
+ size_t rm = m + 1 + (e - (m+1))/2 - tree_.begin();
+ if (subtree_weights_[rm] < min_w)
+ min_w = subtree_weights_[rm];
+ }
+
+ if (weights_[im] < min_w) {
+ min_w = weights_[im];
+ }
+
+ if (weight_increases) {
+
+ if (subtree_weights_[im] < min_w ) // increase weight
+ subtree_weights_[im] = min_w;
+ else
+ break;
+
+ } else {
+
+ if (subtree_weights_[im] > min_w ) // decrease weight
+ subtree_weights_[im] = min_w;
+ else
+ break;
+
+ }
+ }
+}
+
+template<class T>
+typename hera::ws::dnn::KDTree<T>::HandleDistance
+hera::ws::dnn::KDTree<T>::
+find(PointHandle q) const
+{
+ hera::ws::dnn::NNRecord<HandleDistance> nn;
+ search(q, nn);
+ return nn.result;
+}
+
+template<class T>
+typename hera::ws::dnn::KDTree<T>::Result
+hera::ws::dnn::KDTree<T>::
+findR(PointHandle q, DistanceType r) const
+{
+ hera::ws::dnn::rNNRecord<HandleDistance> rnn(r);
+ search(q, rnn);
+ std::sort(rnn.result.begin(), rnn.result.end());
+ return rnn.result;
+}
+
+template<class T>
+typename hera::ws::dnn::KDTree<T>::Result
+hera::ws::dnn::KDTree<T>::
+findK(PointHandle q, size_t k) const
+{
+ hera::ws::dnn::kNNRecord<HandleDistance> knn(k);
+ search(q, knn);
+ std::sort(knn.result.begin(), knn.result.end());
+ return knn.result;
+}
+
+
+template<class T>
+struct hera::ws::dnn::KDTree<T>::CoordinateComparison
+{
+ CoordinateComparison(size_t i, const Traits& traits):
+ i_(i), traits_(traits) {}
+
+ bool operator()(PointHandle p1, PointHandle p2) const { return coordinate(p1) < coordinate(p2); }
+ Coordinate diff(PointHandle p1, PointHandle p2) const { return coordinate(p1) - coordinate(p2); }
+
+ Coordinate coordinate(PointHandle p) const { return traits_.coordinate(p, i_); }
+ size_t axis() const { return i_; }
+
+ private:
+ size_t i_;
+ const Traits& traits_;
+};
+
+template<class T>
+void
+hera::ws::dnn::KDTree<T>::
+printWeights(void)
+{
+#ifndef FOR_R_TDA
+ std::cout << "weights_:" << std::endl;
+ for(const auto ph : indices_) {
+ std::cout << "idx = " << ph.second << ": (" << (ph.first)->at(0) << ", " << (ph.first)->at(1) << ") weight = " << weights_[ph.second] << std::endl;
+ }
+ std::cout << "subtree_weights_:" << std::endl;
+ for(size_t idx = 0; idx < subtree_weights_.size(); ++idx) {
+ std::cout << idx << " : " << subtree_weights_[idx] << std::endl;
+ }
+#endif
+}
+
+
diff --git a/wasserstein/include/dnn/local/search-functors.h b/wasserstein/include/dnn/local/search-functors.h
new file mode 100644
index 0000000..1419f22
--- /dev/null
+++ b/wasserstein/include/dnn/local/search-functors.h
@@ -0,0 +1,95 @@
+#ifndef HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H
+#define HERA_WS_DNN_LOCAL_SEARCH_FUNCTORS_H
+
+#include <boost/range/algorithm/heap_algorithm.hpp>
+
+namespace hera
+{
+namespace ws
+{
+namespace dnn
+{
+
+template<class NN>
+struct HandleDistance
+{
+ typedef typename NN::PointHandle PointHandle;
+ typedef typename NN::DistanceType DistanceType;
+ typedef typename NN::HDContainer HDContainer;
+
+ HandleDistance() {}
+ HandleDistance(PointHandle pp, DistanceType dd):
+ p(pp), d(dd) {}
+ bool operator<(const HandleDistance& other) const { return d < other.d; }
+
+ PointHandle p;
+ DistanceType d;
+};
+
+template<class HandleDistance>
+struct NNRecord
+{
+ typedef typename HandleDistance::PointHandle PointHandle;
+ typedef typename HandleDistance::DistanceType DistanceType;
+
+ NNRecord() { result.d = std::numeric_limits<DistanceType>::max(); }
+ DistanceType operator()(PointHandle p, DistanceType d) { if (d < result.d) { result.p = p; result.d = d; } return result.d; }
+ HandleDistance result;
+};
+
+template<class HandleDistance>
+struct rNNRecord
+{
+ typedef typename HandleDistance::PointHandle PointHandle;
+ typedef typename HandleDistance::DistanceType DistanceType;
+ typedef typename HandleDistance::HDContainer HDContainer;
+
+ rNNRecord(DistanceType r_): r(r_) {}
+ DistanceType operator()(PointHandle p, DistanceType d)
+ {
+ if (d <= r)
+ result.push_back(HandleDistance(p,d));
+ return r;
+ }
+
+ DistanceType r;
+ HDContainer result;
+};
+
+template<class HandleDistance>
+struct kNNRecord
+{
+ typedef typename HandleDistance::PointHandle PointHandle;
+ typedef typename HandleDistance::DistanceType DistanceType;
+ typedef typename HandleDistance::HDContainer HDContainer;
+
+ kNNRecord(unsigned k_): k(k_) {}
+ DistanceType operator()(PointHandle p, DistanceType d)
+ {
+ if (result.size() < k)
+ {
+ result.push_back(HandleDistance(p,d));
+ boost::push_heap(result);
+ if (result.size() < k)
+ return std::numeric_limits<DistanceType>::max();
+ } else if (d < result[0].d)
+ {
+ boost::pop_heap(result);
+ result.back() = HandleDistance(p,d);
+ boost::push_heap(result);
+ }
+ if ( result.size() > 1 ) {
+ assert( result[0].d >= result[1].d );
+ }
+ return result[0].d;
+ }
+
+ unsigned k;
+ HDContainer result;
+};
+
+} // dnn
+} // ws
+} // hera
+
+#endif // DNN_LOCAL_SEARCH_FUNCTORS_H
diff --git a/wasserstein/include/dnn/parallel/tbb.h b/wasserstein/include/dnn/parallel/tbb.h
new file mode 100644
index 0000000..3f811d6
--- /dev/null
+++ b/wasserstein/include/dnn/parallel/tbb.h
@@ -0,0 +1,237 @@
+#ifndef HERA_WS_PARALLEL_H
+#define HERA_WS_PARALLEL_H
+
+#include <vector>
+
+#include <boost/range.hpp>
+#include <boost/bind.hpp>
+#include <boost/foreach.hpp>
+
+#ifdef TBB
+
+#include <tbb/tbb.h>
+#include <tbb/concurrent_hash_map.h>
+#include <tbb/scalable_allocator.h>
+
+#include <boost/serialization/split_free.hpp>
+#include <boost/serialization/collections_load_imp.hpp>
+#include <boost/serialization/collections_save_imp.hpp>
+
+namespace hera
+{
+namespace ws
+{
+namespace dnn
+{
+ using tbb::mutex;
+ using tbb::task_scheduler_init;
+ using tbb::task_group;
+ using tbb::task;
+
+ template<class T>
+ struct vector
+ {
+ typedef tbb::concurrent_vector<T> type;
+ };
+
+ template<class T>
+ struct atomic
+ {
+ typedef tbb::atomic<T> type;
+ static T compare_and_swap(type& v, T n, T o) { return v.compare_and_swap(n,o); }
+ };
+
+ template<class Iterator, class F>
+ void do_foreach(Iterator begin, Iterator end, const F& f) { tbb::parallel_do(begin, end, f); }
+
+ template<class Range, class F>
+ void for_each_range_(const Range& r, const F& f)
+ {
+ for (typename Range::iterator cur = r.begin(); cur != r.end(); ++cur)
+ f(*cur);
+ }
+
+ template<class F>
+ void for_each_range(size_t from, size_t to, const F& f)
+ {
+ //static tbb::affinity_partitioner ap;
+ //tbb::parallel_for(c.range(), boost::bind(&for_each_range_<typename Container::range_type, F>, _1, f), ap);
+ tbb::parallel_for(from, to, f);
+ }
+
+ template<class Container, class F>
+ void for_each_range(const Container& c, const F& f)
+ {
+ //static tbb::affinity_partitioner ap;
+ //tbb::parallel_for(c.range(), boost::bind(&for_each_range_<typename Container::range_type, F>, _1, f), ap);
+ tbb::parallel_for(c.range(), boost::bind(&for_each_range_<typename Container::const_range_type, F>, _1, f));
+ }
+
+ template<class Container, class F>
+ void for_each_range(Container& c, const F& f)
+ {
+ //static tbb::affinity_partitioner ap;
+ //tbb::parallel_for(c.range(), boost::bind(&for_each_range_<typename Container::range_type, F>, _1, f), ap);
+ tbb::parallel_for(c.range(), boost::bind(&for_each_range_<typename Container::range_type, F>, _1, f));
+ }
+
+ template<class ID, class NodePointer, class IDTraits, class Allocator>
+ struct map_traits
+ {
+ typedef tbb::concurrent_hash_map<ID, NodePointer, IDTraits, Allocator> type;
+ typedef typename type::range_type range;
+ };
+
+ struct progress_timer
+ {
+ progress_timer(): start(tbb::tick_count::now()) {}
+ ~progress_timer()
+ { std::cout << (tbb::tick_count::now() - start).seconds() << " s" << std::endl; }
+
+ tbb::tick_count start;
+ };
+} // dnn
+} // ws
+} // hera
+
+// Serialization for tbb::concurrent_vector<...>
+namespace boost
+{
+ namespace serialization
+ {
+ template<class Archive, class T, class A>
+ void save(Archive& ar, const tbb::concurrent_vector<T,A>& v, const unsigned int file_version)
+ { stl::save_collection(ar, v); }
+
+ template<class Archive, class T, class A>
+ void load(Archive& ar, tbb::concurrent_vector<T,A>& v, const unsigned int file_version)
+ {
+ stl::load_collection<Archive,
+ tbb::concurrent_vector<T,A>,
+ stl::archive_input_seq< Archive, tbb::concurrent_vector<T,A> >,
+ stl::reserve_imp< tbb::concurrent_vector<T,A> >
+ >(ar, v);
+ }
+
+ template<class Archive, class T, class A>
+ void serialize(Archive& ar, tbb::concurrent_vector<T,A>& v, const unsigned int file_version)
+ { split_free(ar, v, file_version); }
+
+ template<class Archive, class T>
+ void save(Archive& ar, const tbb::atomic<T>& v, const unsigned int file_version)
+ { T v_ = v; ar << v_; }
+
+ template<class Archive, class T>
+ void load(Archive& ar, tbb::atomic<T>& v, const unsigned int file_version)
+ { T v_; ar >> v_; v = v_; }
+
+ template<class Archive, class T>
+ void serialize(Archive& ar, tbb::atomic<T>& v, const unsigned int file_version)
+ { split_free(ar, v, file_version); }
+ }
+}
+
+#else
+
+#include <algorithm>
+#include <map>
+#include <boost/progress.hpp>
+
+namespace hera
+{
+namespace ws
+{
+namespace dnn
+{
+ template<class T>
+ struct vector
+ {
+ typedef ::std::vector<T> type;
+ };
+
+ template<class T>
+ struct atomic
+ {
+ typedef T type;
+ static T compare_and_swap(type& v, T n, T o) { if (v != o) return v; v = n; return o; }
+ };
+
+ template<class Iterator, class F>
+ void do_foreach(Iterator begin, Iterator end, const F& f) { std::for_each(begin, end, f); }
+
+ template<class F>
+ void for_each_range(size_t from, size_t to, const F& f)
+ {
+ for (size_t i = from; i < to; ++i)
+ f(i);
+ }
+
+ template<class Container, class F>
+ void for_each_range(Container& c, const F& f)
+ {
+ BOOST_FOREACH(const typename Container::value_type& i, c)
+ f(i);
+ }
+
+ template<class Container, class F>
+ void for_each_range(const Container& c, const F& f)
+ {
+ BOOST_FOREACH(const typename Container::value_type& i, c)
+ f(i);
+ }
+
+ struct mutex
+ {
+ struct scoped_lock
+ {
+ scoped_lock() {}
+ scoped_lock(mutex& ) {}
+ void acquire(mutex& ) const {}
+ void release() const {}
+ };
+ };
+
+ struct task_scheduler_init
+ {
+ task_scheduler_init(unsigned) {}
+ void initialize(unsigned) {}
+ static const unsigned automatic = 0;
+ static const unsigned deferred = 0;
+ };
+
+ struct task_group
+ {
+ template<class Functor>
+ void run(const Functor& f) const { f(); }
+ void wait() const {}
+ };
+
+ template<class ID, class NodePointer, class IDTraits, class Allocator>
+ struct map_traits
+ {
+ typedef std::map<ID, NodePointer,
+ typename IDTraits::Comparison,
+ Allocator> type;
+ typedef type range;
+ };
+
+ using boost::progress_timer;
+} // dnn
+} // ws
+} // hera
+
+#endif // TBB
+
+namespace hera
+{
+namespace ws
+{
+namespace dnn
+{
+ template<class Range, class F>
+ void do_foreach(const Range& range, const F& f) { do_foreach(boost::begin(range), boost::end(range), f); }
+} // dnn
+} // ws
+} // hera
+
+#endif
diff --git a/wasserstein/include/dnn/parallel/utils.h b/wasserstein/include/dnn/parallel/utils.h
new file mode 100644
index 0000000..7104ec3
--- /dev/null
+++ b/wasserstein/include/dnn/parallel/utils.h
@@ -0,0 +1,100 @@
+#ifndef HERA_WS_PARALLEL_UTILS_H
+#define HERA_WS_PARALLEL_UTILS_H
+
+#include "../utils.h"
+
+namespace hera
+{
+namespace ws
+{
+namespace dnn
+{
+ // Assumes rng is synchronized across ranks
+ template<class DataVector, class RNGType, class SwapFunctor>
+ void shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty = DataVector());
+
+ template<class DataVector, class RNGType>
+ void shuffle(mpi::communicator& world, DataVector& data, RNGType& rng)
+ {
+ typedef decltype(data[0]) T;
+ shuffle(world, data, rng, [](T& x, T& y) { std::swap(x,y); });
+ }
+} // dnn
+} // ws
+} // hera
+
+template<class DataVector, class RNGType, class SwapFunctor>
+void
+hera::ws::dnn::shuffle(mpi::communicator& world, DataVector& data, RNGType& rng, const SwapFunctor& swap, DataVector empty)
+{
+ // This is not a perfect shuffle: it dishes out data in chunks of 1/size.
+ // (It can be interpreted as generating a bistochastic matrix by taking the
+ // sum of size random permutation matrices.) Hopefully, it works for our purposes.
+
+ typedef typename RNGType::result_type RNGResult;
+
+ int size = world.size();
+ int rank = world.rank();
+
+ // Generate local seeds
+ boost::uniform_int<RNGResult> uniform;
+ RNGResult seed;
+ for (size_t i = 0; i < size; ++i)
+ {
+ RNGResult v = uniform(rng);
+ if (i == rank)
+ seed = v;
+ }
+ RNGType local_rng(seed);
+
+ // Shuffle local data
+ hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap);
+
+ // Decide how much of our data goes to i-th processor
+ std::vector<size_t> out_counts(size);
+ std::vector<int> ranks(boost::counting_iterator<int>(0),
+ boost::counting_iterator<int>(size));
+ for (size_t i = 0; i < size; ++i)
+ {
+ hera::ws::dnn::random_shuffle(ranks.begin(), ranks.end(), rng);
+ ++out_counts[ranks[rank]];
+ }
+
+ // Fill the outgoing array
+ size_t total = 0;
+ std::vector< DataVector > outgoing(size, empty);
+ for (size_t i = 0; i < size; ++i)
+ {
+ size_t count = data.size()*out_counts[i]/size;
+ if (total + count > data.size())
+ count = data.size() - total;
+
+ outgoing[i].reserve(count);
+ for (size_t j = total; j < total + count; ++j)
+ outgoing[i].push_back(data[j]);
+
+ total += count;
+ }
+
+ boost::uniform_int<size_t> uniform_outgoing(0,size-1); // in range [0,size-1]
+ while(total < data.size()) // send leftover to random processes
+ {
+ outgoing[uniform_outgoing(local_rng)].push_back(data[total]);
+ ++total;
+ }
+ data.clear();
+
+ // Exchange the data
+ std::vector< DataVector > incoming(size, empty);
+ mpi::all_to_all(world, outgoing, incoming);
+ outgoing.clear();
+
+ // Assemble our data
+ for(const DataVector& vec : incoming)
+ for (size_t i = 0; i < vec.size(); ++i)
+ data.push_back(vec[i]);
+ hera::ws::dnn::random_shuffle(data.begin(), data.end(), local_rng, swap);
+ // XXX: the final shuffle is irrelevant for our purposes. But it's also cheap.
+}
+
+#endif
diff --git a/wasserstein/include/dnn/utils.h b/wasserstein/include/dnn/utils.h
new file mode 100644
index 0000000..bbce793
--- /dev/null
+++ b/wasserstein/include/dnn/utils.h
@@ -0,0 +1,47 @@
+#ifndef HERA_WS_DNN_UTILS_H
+#define HERA_WS_DNN_UTILS_H
+
+#include <boost/random/uniform_int.hpp>
+#include <boost/foreach.hpp>
+#include <boost/typeof/typeof.hpp>
+
+namespace hera
+{
+namespace ws
+{
+namespace dnn
+{
+
+template <typename T, typename... Args>
+struct has_coordinates
+{
+ template <typename C, typename = decltype( std::declval<C>().coordinate(std::declval<Args>()...) )>
+ static std::true_type test(int);
+
+ template <typename C>
+ static std::false_type test(...);
+
+ static constexpr bool value = decltype(test<T>(0))::value;
+};
+
+template<class RandomIt, class UniformRandomNumberGenerator, class SwapFunctor>
+void random_shuffle(RandomIt first, RandomIt last, UniformRandomNumberGenerator& g, const SwapFunctor& swap)
+{
+ size_t n = last - first;
+ boost::uniform_int<size_t> uniform(0,n);
+ for (size_t i = n-1; i > 0; --i)
+ swap(first[i], first[uniform(g,i+1)]); // picks a random number in [0,i] range
+}
+
+template<class RandomIt, class UniformRandomNumberGenerator>
+void random_shuffle(RandomIt first, RandomIt last, UniformRandomNumberGenerator& g)
+{
+ typedef decltype(*first) T;
+ random_shuffle(first, last, g, [](T& x, T& y) { std::swap(x,y); });
+}
+
+} // dnn
+} // ws
+} // hera
+
+#endif