20 :num_data_(num_data), num_leaves_(
num_leaves) {
21 leaf_begin_.resize(num_leaves_);
22 leaf_count_.resize(num_leaves_);
23 indices_.resize(num_data_);
24 temp_left_indices_.resize(num_data_);
25 temp_right_indices_.resize(num_data_);
26 used_data_indices_ =
nullptr;
30 num_threads_ = omp_get_num_threads();
32 offsets_buf_.resize(num_threads_);
33 left_cnts_buf_.resize(num_threads_);
34 right_cnts_buf_.resize(num_threads_);
35 left_write_pos_buf_.resize(num_threads_);
36 right_write_pos_buf_.resize(num_threads_);
41 leaf_begin_.resize(num_leaves_);
42 leaf_count_.resize(num_leaves_);
44 void ResetNumData(
int num_data) {
46 indices_.resize(num_data_);
47 temp_left_indices_.resize(num_data_);
48 temp_right_indices_.resize(num_data_);
57 std::fill(leaf_begin_.begin(), leaf_begin_.end(), 0);
58 std::fill(leaf_count_.begin(), leaf_count_.end(), 0);
59 if (used_data_indices_ ==
nullptr) {
61 leaf_count_[0] = num_data_;
62 #pragma omp parallel for schedule(static)
68 leaf_count_[0] = used_data_count_;
69 std::memcpy(indices_.data(), used_data_indices_, used_data_count_ *
sizeof(
data_size_t));
73 void ResetByLeafPred(
const std::vector<int>& leaf_pred,
int num_leaves) {
75 std::vector<std::vector<data_size_t>> indices_per_leaf(num_leaves_);
76 for (
data_size_t i = 0; i < static_cast<data_size_t>(leaf_pred.size()); ++i) {
77 indices_per_leaf[leaf_pred[i]].push_back(i);
80 for (
int i = 0; i < num_leaves_; ++i) {
81 leaf_begin_[i] = offset;
82 leaf_count_[i] =
static_cast<data_size_t>(indices_per_leaf[i].size());
83 std::copy(indices_per_leaf[i].begin(), indices_per_leaf[i].end(), indices_.begin() + leaf_begin_[i]);
84 offset += leaf_count_[i];
97 *out_len = leaf_count_[leaf];
98 return indices_.data() + begin;
108 void Split(
int leaf,
const Dataset* dataset,
int feature,
const uint32_t* threshold,
int num_threshold,
bool default_left,
int right_leaf) {
114 data_size_t inner_size = (cnt + num_threads_ - 1) / num_threads_;
115 if (inner_size < min_inner_size) { inner_size = min_inner_size; }
118 #pragma omp parallel for schedule(static, 1)
119 for (
int i = 0; i < num_threads_; ++i) {
121 left_cnts_buf_[i] = 0;
122 right_cnts_buf_[i] = 0;
124 if (cur_start > cnt) {
continue; }
126 if (cur_start + cur_cnt > cnt) { cur_cnt = cnt - cur_start; }
128 data_size_t cur_left_count = dataset->Split(feature, threshold, num_threshold, default_left, indices_.data() + begin + cur_start, cur_cnt,
129 temp_left_indices_.data() + cur_start, temp_right_indices_.data() + cur_start);
130 offsets_buf_[i] = cur_start;
131 left_cnts_buf_[i] = cur_left_count;
132 right_cnts_buf_[i] = cur_cnt - cur_left_count;
137 left_write_pos_buf_[0] = 0;
138 right_write_pos_buf_[0] = 0;
139 for (
int i = 1; i < num_threads_; ++i) {
140 left_write_pos_buf_[i] = left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1];
141 right_write_pos_buf_[i] = right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1];
143 left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];
145 #pragma omp parallel for schedule(static, 1)
146 for (
int i = 0; i < num_threads_; ++i) {
147 if (left_cnts_buf_[i] > 0) {
148 std::memcpy(indices_.data() + begin + left_write_pos_buf_[i],
149 temp_left_indices_.data() + offsets_buf_[i], left_cnts_buf_[i] *
sizeof(
data_size_t));
151 if (right_cnts_buf_[i] > 0) {
152 std::memcpy(indices_.data() + begin + left_cnt + right_write_pos_buf_[i],
153 temp_right_indices_.data() + offsets_buf_[i], right_cnts_buf_[i] *
sizeof(
data_size_t));
157 leaf_count_[leaf] = left_cnt;
158 leaf_begin_[right_leaf] = left_cnt + begin;
159 leaf_count_[right_leaf] = cnt - left_cnt;
168 used_data_indices_ = used_data_indices;
169 used_data_count_ = num_used_data;
186 const data_size_t* indices()
const {
return indices_.data(); }
197 std::vector<data_size_t> leaf_begin_;
199 std::vector<data_size_t> leaf_count_;
201 std::vector<data_size_t> indices_;
203 std::vector<data_size_t> temp_left_indices_;
205 std::vector<data_size_t> temp_right_indices_;
213 std::vector<data_size_t> offsets_buf_;
215 std::vector<data_size_t> left_cnts_buf_;
217 std::vector<data_size_t> right_cnts_buf_;
219 std::vector<data_size_t> left_write_pos_buf_;
221 std::vector<data_size_t> right_write_pos_buf_;
void Split(int leaf, const Dataset *dataset, int feature, const uint32_t *threshold, int num_threshold, bool default_left, int right_leaf)
Split the data.
Definition data_partition.hpp:108
void SetUsedDataIndices(const data_size_t *used_data_indices, data_size_t num_used_data)
Set used data indices before training; used for bagging.
Definition data_partition.hpp:167
const data_size_t * GetIndexOnLeaf(int leaf, data_size_t *out_len) const
Get the data indices of one leaf.
Definition data_partition.hpp:94