#pragma once

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace kdbush {

/// Customization point that extracts coordinate I (0 = x, 1 = y) from a point
/// of type T. The primary template forwards to std::get / std::tuple_element,
/// so it works for tuple-like types; specialize it for other point types.
template <std::uint8_t I, typename T>
struct nth {
    inline static typename std::tuple_element<I, T>::type get(const T &t) {
        return std::get<I>(t);
    }
};

/// Static 2D k-d tree stored as a flat array.
///
/// fill() copies the input points and reorders them in place: every index
/// subrange [left, right] with more than nodeSize + 1 elements is split at its
/// midpoint m (the element that would be at m if the subrange were sorted on
/// the current axis), and the two halves are processed on the other axis.
/// Smaller subranges are left unsorted and scanned linearly by queries.
///
/// Requirements on TPoint, as used below:
///  - nth<0, TPoint>::get and nth<1, TPoint>::get return the same type;
///  - TPoint is constructible from (x, y);
///  - TPoint has a member `coords` on which std::get<0> / std::get<1> work.
///
/// NOTE(review): the template parameter lists and the element type of
/// `points` were lost from the original text; they are reconstructed here as
/// `<typename TPoint, typename TIndex = std::size_t>` and
/// `std::vector<TPoint>` — confirm against callers.
template <typename TPoint, typename TIndex = std::size_t>
class KDBush {
public:
    using TNumber = decltype(nth<0, TPoint>::get(std::declval<TPoint>()));
    static_assert(
        std::is_same<TNumber, decltype(nth<1, TPoint>::get(std::declval<TPoint>()))>::value,
        "point component types must be identical");

    static const std::uint8_t defaultNodeSize = 64;

    /// Creates an empty index; load it later with fill().
    /// @param nodeSize_ leaf-size threshold; 0 is treated as 1.
    KDBush(const std::uint8_t nodeSize_ = defaultNodeSize)
        : nodeSize(sanitizeNodeSize(nodeSize_)) {
    }

    /// Builds the index from a vector of points (the points are copied).
    KDBush(const std::vector<TPoint> &points_, const std::uint8_t nodeSize_ = defaultNodeSize)
        : KDBush(std::begin(points_), std::end(points_), nodeSize_) {
    }

    /// Builds the index from the iterator range [points_begin, points_end).
    template <typename TPointIter>
    KDBush(const TPointIter &points_begin,
           const TPointIter &points_end,
           const std::uint8_t nodeSize_ = defaultNodeSize)
        : nodeSize(sanitizeNodeSize(nodeSize_)) {
        fill(points_begin, points_end);
    }

    /// Loads points into an empty index (asserts that it is empty).
    void fill(const std::vector<TPoint> &points_) {
        fill(std::begin(points_), std::end(points_));
    }

    /// Loads the points in [points_begin, points_end) into an empty index and
    /// builds the tree. An empty range leaves the index empty.
    template <typename TPointIter>
    void fill(const TPointIter &points_begin, const TPointIter &points_end) {
        assert(points.empty());
        const TIndex size = static_cast<TIndex>(std::distance(points_begin, points_end));

        if (size == 0) return;

        points.reserve(size);
        for (auto p = points_begin; p != points_end; ++p) {
            points.emplace_back(nth<0, TPoint>::get(*p), nth<1, TPoint>::get(*p));
        }

        sortKD(0, size - 1, 0);
    }

    /// Calls visitor(point) for every stored point with
    /// minX <= x <= maxX and minY <= y <= maxY (bounds inclusive).
    /// The point is passed as a const reference into the index.
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor) const {
        // Avoid forming points.size() - 1 on an empty index.
        if (points.empty()) return;
        range(minX, minY, maxX, maxY, visitor, 0, static_cast<TIndex>(points.size() - 1), 0);
    }

    /// Calls visitor(point) for every stored point whose squared distance to
    /// (qx, qy) is <= r * r.
    template <typename TVisitor>
    void within(const TNumber qx, const TNumber qy, const TNumber r, const TVisitor &visitor) const {
        // Avoid forming points.size() - 1 on an empty index.
        if (points.empty()) return;
        within(qx, qy, r, visitor, 0, static_cast<TIndex>(points.size() - 1), 0);
    }

protected:
    /// Stored points, reordered in place by sortKD() into k-d tree order.
    std::vector<TPoint> points;
    /// Subranges with right - left <= nodeSize are leaves (scanned linearly).
    const std::uint8_t nodeSize;

    /// Returns a usable leaf-size threshold. 0 is mapped to 1: with a threshold
    /// of 0 a two-element range splits at m == left, and recursing on
    /// [left, m - 1] wraps an unsigned TIndex when left == 0, indexing far
    /// outside `points`.
    static std::uint8_t sanitizeNodeSize(const std::uint8_t size) {
        return size == 0 ? std::uint8_t{1} : size;
    }

    /// Recursive rectangle query over the subrange [left, right], which was
    /// split on `axis` (0 = x, 1 = y) by sortKD().
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor,
               const TIndex left,
               const TIndex right,
               const std::uint8_t axis) const {
        if (points.empty()) return;

        if (right - left <= nodeSize) {
            // Leaf: unsorted, scan every point.
            for (auto i = left; i <= right; i++) {
                const TNumber x = std::get<0>(points[i].coords);
                const TNumber y = std::get<1>(points[i].coords);
                if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(points[i]);
            }
            return;
        }

        const TIndex m = (left + right) >> 1;
        const TNumber x = std::get<0>(points[m].coords);
        const TNumber y = std::get<1>(points[m].coords);

        if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(points[m]);

        // Points left of m are <= the median on `axis`, points right of it are >=.
        if (axis == 0 ? minX <= x : minY <= y)
            range(minX, minY, maxX, maxY, visitor, left, m - 1, (axis + 1) % 2);

        if (axis == 0 ? maxX >= x : maxY >= y)
            range(minX, minY, maxX, maxY, visitor, m + 1, right, (axis + 1) % 2);
    }

    /// Recursive radius query over the subrange [left, right], which was split
    /// on `axis` (0 = x, 1 = y) by sortKD().
    template <typename TVisitor>
    void within(const TNumber qx,
                const TNumber qy,
                const TNumber r,
                const TVisitor &visitor,
                const TIndex left,
                const TIndex right,
                const std::uint8_t axis) const {
        if (points.empty()) return;

        const TNumber r2 = r * r;

        if (right - left <= nodeSize) {
            // Leaf: unsorted, scan every point.
            for (auto i = left; i <= right; i++) {
                const TNumber x = std::get<0>(points[i].coords);
                const TNumber y = std::get<1>(points[i].coords);
                if (sqDist(x, y, qx, qy) <= r2) visitor(points[i]);
            }
            return;
        }

        const TIndex m = (left + right) >> 1;
        const TNumber x = std::get<0>(points[m].coords);
        const TNumber y = std::get<1>(points[m].coords);

        if (sqDist(x, y, qx, qy) <= r2) visitor(points[m]);

        // Descend into a half only if the query circle's bounding box reaches it.
        if (axis == 0 ? qx - r <= x : qy - r <= y)
            within(qx, qy, r, visitor, left, m - 1, (axis + 1) % 2);

        if (axis == 0 ? qx + r >= x : qy + r >= y)
            within(qx, qy, r, visitor, m + 1, right, (axis + 1) % 2);
    }

    /// Reorders [left, right] into k-d tree order: places the median of the
    /// subrange (by `axis`) at its midpoint, then recurses on both halves with
    /// the other axis. Subranges with right - left <= nodeSize are left as is.
    void sortKD(const TIndex left, const TIndex right, const std::uint8_t axis) {
        if (right - left <= nodeSize) return;

        const TIndex m = (left + right) >> 1;
        if (axis == 0) {
            select<0>(m, left, right);
        } else {
            select<1>(m, left, right);
        }

        sortKD(left, m - 1, (axis + 1) % 2);
        sortKD(m + 1, right, (axis + 1) % 2);
    }

    /// Floyd–Rivest selection on coordinate I: rearranges [left, right] so
    /// that points[k] holds the element that would be there if the subrange
    /// were sorted by coordinate I. Elements before k compare <= it, and
    /// elements after compare >=.
    template <std::uint8_t I>
    void select(const TIndex k, TIndex left, TIndex right) {
        while (right > left) {
            if (right - left > 600) {
                // Recursively select within a small sample window around k first,
                // so the partition below starts from a good pivot.
                const double n = static_cast<double>(right - left + 1);
                const double m = static_cast<double>(k - left + 1);
                const double z = std::log(n);
                const double s = 0.5 * std::exp(2 * z / 3);
                const double r = static_cast<double>(k) - m * s / n +
                                 0.5 * std::sqrt(z * s * (1 - s / n)) * (2 * m < n ? -1 : 1);
                // Clamp while still in floating point: r can be negative (e.g. k
                // near 0), and converting a negative double to an unsigned TIndex
                // is undefined behaviour.
                const double newLeft = std::max(static_cast<double>(left), r);
                const double newRight =
                    std::max(newLeft, std::min(static_cast<double>(right), r + s));
                select<I>(k, static_cast<TIndex>(newLeft), static_cast<TIndex>(newRight));
            }

            const TNumber t = std::get<I>(points[k].coords);
            TIndex i = left;
            TIndex j = right;

            // Move the pivot to `left` and make sure points[right] >= t, so the
            // scans below are bounded by sentinels.
            swapItem(left, k);
            if (std::get<I>(points[right].coords) > t) swapItem(left, right);

            while (i < j) {
                swapItem(i++, j--);
                while (std::get<I>(points[i].coords) < t) i++;
                while (std::get<I>(points[j].coords) > t) j--;
            }

            // Put the pivot into its final position j.
            if (std::get<I>(points[left].coords) == t) {
                swapItem(left, j);
            } else {
                swapItem(++j, right);
            }

            // Narrow the search window to the side that still contains k.
            if (j <= k) left = j + 1;
            if (k <= j) right = j - 1;
        }
    }

    /// Swaps the stored points at indices i and j.
    void swapItem(const TIndex i, const TIndex j) {
        using std::swap;
        swap(points[i], points[j]);
    }

    /// Squared Euclidean distance between (ax, ay) and (bx, by).
    TNumber sqDist(const TNumber ax, const TNumber ay, const TNumber bx, const TNumber by) const {
        auto dx = ax - bx;
        auto dy = ay - by;
        return dx * dx + dy * dy;
    }
};

} // namespace kdbush