Medial Code Documentation
Loading...
Searching...
No Matches
algorithm.h
1
4#ifndef XGBOOST_COMMON_ALGORITHM_H_
5#define XGBOOST_COMMON_ALGORITHM_H_
6#include <algorithm> // upper_bound, stable_sort, sort, max
7#include <cinttypes> // size_t
8#include <functional> // less
9#include <iterator> // iterator_traits, distance
10#include <vector> // vector
11
12#include "numeric.h" // Iota
13#include "xgboost/context.h" // Context
14
15// clang with libstdc++ works as well
16#if defined(__GNUC__) && (__GNUC__ >= 4) && !defined(__sun) && !defined(sun) && \
17 !defined(__APPLE__) && __has_include(<omp.h>) && __has_include(<parallel/algorithm>)
18#define GCC_HAS_PARALLEL 1
19#endif // GLIC_VERSION
20
21#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
22#define MSVC_HAS_PARALLEL 1
23#endif // MSC
24
25#if defined(GCC_HAS_PARALLEL)
26#include <parallel/algorithm>
27#elif defined(MSVC_HAS_PARALLEL)
28#include <ppl.h>
29#endif // GLIBC VERSION
30
31namespace xgboost {
32namespace common {
/**
 * @brief Locate the segment that contains @p idx.
 *
 * [first, last) is treated as a sorted (ascending) list of segment start offsets. The
 * result is the position of the last element that is <= @p idx, i.e.
 * `upper_bound(first, last, idx) - first - 1`.
 *
 * @param first Begin of the sorted boundary sequence.
 * @param last  End of the sorted boundary sequence.
 * @param idx   Value to locate.
 * @return Segment index as std::size_t. If @p idx is smaller than every element
 *         (including an empty range), the result wraps to `static_cast<std::size_t>(-1)`.
 *         NOTE(review): callers presumably never hit that case; confirm.
 */
template <typename It, typename Idx>
auto SegmentId(It first, It last, Idx idx) {
  // Take the distance before subtracting: `upper_bound(...) - 1` would form an iterator
  // before `first` (undefined behaviour) whenever idx < *first.
  auto const n_not_greater = std::distance(first, std::upper_bound(first, last, idx));
  std::size_t segment_id = static_cast<std::size_t>(n_not_greater) - 1;
  return segment_id;
}
38
39template <typename Iter, typename Comp>
40void StableSort(Context const *ctx, Iter begin, Iter end, Comp &&comp) {
41 if (ctx->Threads() > 1) {
42#if defined(GCC_HAS_PARALLEL)
43 __gnu_parallel::stable_sort(begin, end, comp,
44 __gnu_parallel::default_parallel_tag(ctx->Threads()));
45#else
46 // the only stable sort is radix sort for msvc ppl.
47 std::stable_sort(begin, end, comp);
48#endif // GLIBC VERSION
49 } else {
50 std::stable_sort(begin, end, comp);
51 }
52}
53
54template <typename Iter, typename Comp>
55void Sort(Context const *ctx, Iter begin, Iter end, Comp comp) {
56 if (ctx->Threads() > 1) {
57#if defined(GCC_HAS_PARALLEL)
58 __gnu_parallel::sort(begin, end, comp, __gnu_parallel::default_parallel_tag(ctx->Threads()));
59#elif defined(MSVC_HAS_PARALLEL)
60 auto n = std::distance(begin, end);
61 // use chunk size as hint to number of threads. No local policy/scheduler input with the
62 // concurrency module.
63 std::size_t chunk_size = n / ctx->Threads();
64 // 2048 is the default of msvc ppl as of v2022.
65 chunk_size = std::max(chunk_size, static_cast<std::size_t>(2048));
66 concurrency::parallel_sort(begin, end, comp, chunk_size);
67#else
68 std::sort(begin, end, comp);
69#endif // GLIBC VERSION
70 } else {
71 std::sort(begin, end, comp);
72 }
73}
74
75template <typename Idx, typename Iter, typename V = typename std::iterator_traits<Iter>::value_type,
76 typename Comp = std::less<V>>
77std::vector<Idx> ArgSort(Context const *ctx, Iter begin, Iter end, Comp comp = std::less<V>{}) {
78 CHECK(ctx->IsCPU());
79 auto n = std::distance(begin, end);
80 std::vector<Idx> result(n);
81 Iota(ctx, result.begin(), result.end(), 0);
82 auto op = [&](Idx const &l, Idx const &r) { return comp(begin[l], begin[r]); };
83 StableSort(ctx, result.begin(), result.end(), op);
84 return result;
85}
86} // namespace common
87} // namespace xgboost
88
89#if defined(GCC_HAS_PARALLEL)
90#undef GCC_HAS_PARALLEL
91#endif // defined(GCC_HAS_PARALLEL)
92
93#if defined(MSVC_HAS_PARALLEL)
94#undef MSVC_HAS_PARALLEL
95#endif // defined(MSVC_HAS_PARALLEL)
96
97#endif // XGBOOST_COMMON_ALGORITHM_H_
Copyright 2014-2023, XGBoost Contributors.
namespace of xgboost
Definition base.h:90
std::int32_t Threads() const
Returns the automatically chosen number of threads based on the nthread parameter and the system sett...
Definition context.cc:203
bool IsCPU() const
Is XGBoost running on CPU?
Definition context.h:133