Medial Code Documentation
Loading...
Searching...
No Matches
cache.h
1
4#ifndef XGBOOST_CACHE_H_
5#define XGBOOST_CACHE_H_
6
7#include <xgboost/logging.h> // for CHECK_EQ, CHECK
8
9#include <cstddef> // for size_t
10#include <memory> // for weak_ptr, shared_ptr, make_shared
11#include <mutex> // for mutex, lock_guard
12#include <queue> // for queue
13#include <thread> // for thread
14#include <unordered_map> // for unordered_map
15#include <utility> // for move
16#include <vector> // for vector
17
18namespace xgboost {
19class DMatrix;
25template <typename CacheT>
27 public:
28 struct Item {
29 // A weak pointer for checking whether the DMatrix object has expired.
30 std::weak_ptr<DMatrix> ref;
31 // The cached item
32 std::shared_ptr<CacheT> value;
33
34 CacheT const& Value() const { return *value; }
35 CacheT& Value() { return *value; }
36
37 Item(std::shared_ptr<DMatrix> m, std::shared_ptr<CacheT> v) : ref{m}, value{std::move(v)} {}
38 };
39
// Default capacity (number of entries) of the cache.
static constexpr std::size_t DefaultSize() {
  constexpr std::size_t kDefaultCacheSize = 32;
  return kDefaultCacheSize;
}
41
42 private:
// Mutex taken by the public member functions while they read or modify the cache.
// Declared mutable so that const members can lock it as well.
43 mutable std::mutex lock_;
44
45 protected:
46 struct Key {
47 DMatrix const* ptr;
48 std::thread::id const thread_id;
49
50 bool operator==(Key const& that) const {
51 return ptr == that.ptr && thread_id == that.thread_id;
52 }
53 };
54 struct Hash {
55 std::size_t operator()(Key const& key) const noexcept {
56 std::size_t f = std::hash<DMatrix const*>()(key.ptr);
57 std::size_t s = std::hash<std::thread::id>()(key.thread_id);
58 if (f == s) {
59 return f;
60 }
61 return f ^ s;
62 }
63 };
64
// Cached items, looked up by (DMatrix address, thread id).
65 std::unordered_map<Key, Item, Hash> container_;
// Keys in insertion order; eviction removes entries starting from the front.
66 std::queue<Key> queue_;
// Capacity passed to the constructor; CacheItem evicts entries once the container
// reaches this size.
67 std::size_t max_size_;
68
// Verify that the eviction queue and the hash map hold the same number of entries.
69 void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
70
71 void ClearExpired() {
72 // Clear expired entries
73 this->CheckConsistent();
74 std::vector<Key> expired;
75 std::queue<Key> remained;
76
77 while (!queue_.empty()) {
78 auto p_fmat = queue_.front();
79 auto it = container_.find(p_fmat);
80 CHECK(it != container_.cend());
81 if (it->second.ref.expired()) {
82 expired.push_back(it->first);
83 } else {
84 remained.push(it->first);
85 }
86 queue_.pop();
87 }
88 CHECK(queue_.empty());
89 CHECK_EQ(remained.size() + expired.size(), container_.size());
90
91 for (auto const& key : expired) {
92 container_.erase(key);
93 }
94 while (!remained.empty()) {
95 auto p_fmat = remained.front();
96 queue_.push(p_fmat);
97 remained.pop();
98 }
99 this->CheckConsistent();
100 }
101
102 void ClearExcess() {
103 this->CheckConsistent();
104 // clear half of the entries to prevent repeatingly clearing cache.
105 std::size_t half_size = max_size_ / 2;
106 while (queue_.size() >= half_size && !queue_.empty()) {
107 auto p_fmat = queue_.front();
108 queue_.pop();
109 container_.erase(p_fmat);
110 }
111 this->CheckConsistent();
112 }
113
114 public:
118 explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
119
120 DMatrixCache& operator=(DMatrixCache&& that) {
121 CHECK(lock_.try_lock());
122 lock_.unlock();
123 CHECK(that.lock_.try_lock());
124 that.lock_.unlock();
125 std::swap(this->container_, that.container_);
126 std::swap(this->queue_, that.queue_);
127 std::swap(this->max_size_, that.max_size_);
128 return *this;
129 }
130
144 template <typename... Args>
145 std::shared_ptr<CacheT> CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
146 CHECK(m);
147 std::lock_guard<std::mutex> guard{lock_};
148
149 this->ClearExpired();
150 if (container_.size() >= max_size_) {
151 this->ClearExcess();
152 }
153 // after clear, cache size < max_size
154 CHECK_LT(container_.size(), max_size_);
155 auto key = Key{m.get(), std::this_thread::get_id()};
156 auto it = container_.find(key);
157 if (it == container_.cend()) {
158 // after the new DMatrix, cache size is at most max_size
159 container_.emplace(key, Item{m, std::make_shared<CacheT>(args...)});
160 queue_.emplace(key);
161 }
162 return container_.at(key).value;
163 }
173 template <typename... Args>
174 std::shared_ptr<CacheT> ResetItem(std::shared_ptr<DMatrix> m, Args const&... args) {
175 std::lock_guard<std::mutex> guard{lock_};
176 CheckConsistent();
177 auto key = Key{m.get(), std::this_thread::get_id()};
178 auto it = container_.find(key);
179 CHECK(it != container_.cend());
180 it->second = {m, std::make_shared<CacheT>(args...)};
181 CheckConsistent();
182 return it->second.value;
183 }
188 decltype(container_) const& Container() {
189 std::lock_guard<std::mutex> guard{lock_};
190
191 this->ClearExpired();
192 return container_;
193 }
194
195 std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
196 std::lock_guard<std::mutex> guard{lock_};
197 auto key = Key{m, std::this_thread::get_id()};
198 CHECK(container_.find(key) != container_.cend());
199 CHECK(!container_.at(key).ref.expired());
200 return container_.at(key).value;
201 }
202};
203} // namespace xgboost
204#endif // XGBOOST_CACHE_H_
Thread-aware FIFO cache for DMatrix related data.
Definition cache.h:26
decltype(container_) const & Container()
Get a const reference to the underlying hash map.
Definition cache.h:188
DMatrixCache(std::size_t cache_size)
Definition cache.h:118
std::shared_ptr< CacheT > CacheItem(std::shared_ptr< DMatrix > m, Args const &... args)
Cache a new DMatrix if it's not in the cache already.
Definition cache.h:145
std::shared_ptr< CacheT > ResetItem(std::shared_ptr< DMatrix > m, Args const &... args)
Re-initialize the item in cache.
Definition cache.h:174
Internal data structure used by XGBoost during training.
Definition data.h:509
Definition json.h:26
defines console logging options for xgboost. Use to enforce unified print behavior.
NLOHMANN_BASIC_JSON_TPL_DECLARATION void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL &j1, nlohmann::NLOHMANN_BASIC_JSON_TPL &j2) noexcept(//NOLINT(readability-inconsistent-declaration-parameter-name, cert-dcl58-cpp) is_nothrow_move_constructible< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value &&//NOLINT(misc-redundant-expression) is_nothrow_move_assignable< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value)
exchanges the values of two JSON objects
Definition json.hpp:24418
namespace of xgboost
Definition base.h:90
Definition cache.h:54
Definition cache.h:28
Definition cache.h:46
Element from a sparse vector.
Definition data.h:216