Medial Code Documentation
Loading...
Searching...
No Matches
parallel_tree_learner.h
1#ifndef LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
2#define LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
3
4#include "serial_tree_learner.h"
5#include "gpu_tree_learner.h"
6#include <LightGBM/network.h>
7
8#include <LightGBM/utils/array_args.h>
9
10#include <cstring>
11#include <vector>
12#include <memory>
13
14namespace LightGBM {
15
// Feature parallel learning algorithm: different machines find the best split
// on different features. The class derives publicly from TREELEARNER_T and
// overrides a subset of its virtual hooks; every member function declared here
// is defined out of line.
21template <typename TREELEARNER_T>
22class FeatureParallelTreeLearner: public TREELEARNER_T {
23public:
 // Constructs the learner from a configuration object.
24 explicit FeatureParallelTreeLearner(const Config* config);
 // Overrides the base learner's Init; receives the training dataset and a flag
 // saying whether the hessian is constant.
26 void Init(const Dataset* train_data, bool is_constant_hessian) override;
27
28protected:
 // Overrides the base learner's BeforeTrain hook.
29 void BeforeTrain() override;
 // Overrides the base learner's split search over already-built histograms.
 // NOTE(review): is_feature_used presumably holds one flag per feature and
 // use_subtract presumably enables histogram subtraction - confirm in the .cpp.
30 void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
31private:
 // NOTE(review): presumably the rank of the local machine - confirm in the .cpp.
33 int rank_;
 // NOTE(review): presumably the number of machines taking part - confirm in the .cpp.
35 int num_machines_;
 // Byte buffer; presumably the send side of network communication - TODO confirm.
37 std::vector<char> input_buffer_;
 // Byte buffer; presumably the receive side of network communication - TODO confirm.
39 std::vector<char> output_buffer_;
40};
41
// Data parallel learning algorithm: workers use local data to construct
// histograms locally. The class derives publicly from TREELEARNER_T and
// overrides a subset of its virtual hooks; apart from the inline
// GetGlobalDataCountInLeaf, every member function is defined out of line.
47template <typename TREELEARNER_T>
48class DataParallelTreeLearner: public TREELEARNER_T {
49public:
 // Constructs the learner from a configuration object.
50 explicit DataParallelTreeLearner(const Config* config);
 // Overrides the base learner's Init; receives the training dataset and a flag
 // saying whether the hessian is constant.
52 void Init(const Dataset* train_data, bool is_constant_hessian) override;
 // Overrides the base learner's ResetConfig with a new configuration object.
53 void ResetConfig(const Config* config) override;
54
55protected:
 // Overrides the base learner's BeforeTrain hook.
56 void BeforeTrain() override;
 // Overrides the base learner's FindBestSplits hook.
57 void FindBestSplits() override;
 // Overrides the base learner's split search over already-built histograms.
 // NOTE(review): is_feature_used presumably holds one flag per feature and
 // use_subtract presumably enables histogram subtraction - confirm in the .cpp.
58 void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
 // Overrides the base learner's Split. left_leaf and right_leaf are pointers,
 // presumably used as outputs for the resulting leaf indices - TODO confirm.
59 void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
60
 // Returns global_data_count_in_leaf_[leaf_idx] when leaf_idx is non-negative,
 // and 0 for a negative leaf_idx. No upper-bound check is performed on leaf_idx.
61 inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
62 if (leaf_idx >= 0) {
63 return global_data_count_in_leaf_[leaf_idx];
64 } else {
65 return 0;
66 }
67 }
68
69private:
 // NOTE(review): presumably the rank of the local machine - confirm in the .cpp.
71 int rank_;
 // NOTE(review): presumably the number of machines taking part - confirm in the .cpp.
73 int num_machines_;
 // Byte buffer; presumably the send side of network communication - TODO confirm.
75 std::vector<char> input_buffer_;
 // Byte buffer; presumably the receive side of network communication - TODO confirm.
77 std::vector<char> output_buffer_;
 // Per-entry flags; presumably one per feature, marking aggregated features - TODO confirm.
80 std::vector<bool> is_feature_aggregated_;
 // NOTE(review): block_start_ / block_len_ presumably describe per-machine
 // blocks inside the communication buffers - confirm in the .cpp.
82 std::vector<comm_size_t> block_start_;
84 std::vector<comm_size_t> block_len_;
 // NOTE(review): presumably buffer offsets at which data is written - confirm in the .cpp.
86 std::vector<comm_size_t> buffer_write_start_pos_;
 // NOTE(review): presumably buffer offsets from which data is read - confirm in the .cpp.
88 std::vector<comm_size_t> buffer_read_start_pos_;
 // NOTE(review): presumably the byte count passed to a reduce-scatter call - confirm in the .cpp.
90 comm_size_t reduce_scatter_size_;
 // Data counts indexed by leaf index; read by GetGlobalDataCountInLeaf.
92 std::vector<data_size_t> global_data_count_in_leaf_;
93};
94
// Voting based data parallel learning algorithm. The class derives publicly
// from TREELEARNER_T and overrides a subset of its virtual hooks; apart from
// the inline GetGlobalDataCountInLeaf, every member function is defined out of
// line.
101template <typename TREELEARNER_T>
102class VotingParallelTreeLearner: public TREELEARNER_T {
103public:
 // Constructs the learner from a configuration object.
104 explicit VotingParallelTreeLearner(const Config* config);
 // Overrides the base learner's Init; receives the training dataset and a flag
 // saying whether the hessian is constant.
106 void Init(const Dataset* train_data, bool is_constant_hessian) override;
 // Overrides the base learner's ResetConfig with a new configuration object.
107 void ResetConfig(const Config* config) override;
108
109protected:
 // Overrides the base learner's BeforeTrain hook.
110 void BeforeTrain() override;
 // Overrides the base learner's BeforeFindBestSplit hook.
 // NOTE(review): the meaning of the bool result is defined by the base class,
 // which is not visible here - confirm before relying on it.
111 bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
 // Overrides the base learner's FindBestSplits hook.
112 void FindBestSplits() override;
 // Overrides the base learner's split search over already-built histograms.
 // NOTE(review): is_feature_used presumably holds one flag per feature and
 // use_subtract presumably enables histogram subtraction - confirm in the .cpp.
113 void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
 // Overrides the base learner's Split. left_leaf and right_leaf are pointers,
 // presumably used as outputs for the resulting leaf indices - TODO confirm.
114 void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
115
 // Returns global_data_count_in_leaf_[leaf_idx] when leaf_idx is non-negative,
 // and 0 for a negative leaf_idx. No upper-bound check is performed on leaf_idx.
116 inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
117 if (leaf_idx >= 0) {
118 return global_data_count_in_leaf_[leaf_idx];
119 } else {
120 return 0;
121 }
122 }
 // Performs global voting for leaf leaf_idx over the given split candidates.
 // out is a pointer to a vector of int; presumably it receives the selected
 // feature indices - TODO confirm in voting_parallel_tree_learner.cpp.
129 void GlobalVoting(int leaf_idx, const std::vector<LightSplitInfo>& splits,
130 std::vector<int>* out);
 // Copies local histograms to a buffer, given the top-feature lists for the
 // smaller and the larger leaf.
136 void CopyLocalHistogram(const std::vector<int>& smaller_top_features,
137 const std::vector<int>& larger_top_features);
138
139private:
 // A Config held by value; presumably a locally adjusted copy of the global
 // configuration - TODO confirm in the .cpp.
141 Config local_config_;
 // NOTE(review): presumably the number of top features each machine votes for - confirm.
143 int top_k_;
 // NOTE(review): presumably the rank of the local machine - confirm in the .cpp.
145 int rank_;
 // NOTE(review): presumably the number of machines taking part - confirm in the .cpp.
147 int num_machines_;
 // Byte buffer; presumably the send side of network communication - TODO confirm.
149 std::vector<char> input_buffer_;
 // Byte buffer; presumably the receive side of network communication - TODO confirm.
151 std::vector<char> output_buffer_;
 // Per-entry flags for the smaller / larger leaf; presumably one per feature,
 // marking aggregated features - TODO confirm.
154 std::vector<bool> smaller_is_feature_aggregated_;
157 std::vector<bool> larger_is_feature_aggregated_;
 // NOTE(review): block_start_ / block_len_ presumably describe per-machine
 // blocks inside the communication buffers - confirm in the .cpp.
159 std::vector<comm_size_t> block_start_;
161 std::vector<comm_size_t> block_len_;
 // NOTE(review): presumably buffer read offsets for the smaller / larger leaf - confirm.
163 std::vector<comm_size_t> smaller_buffer_read_start_pos_;
165 std::vector<comm_size_t> larger_buffer_read_start_pos_;
 // NOTE(review): presumably the byte count passed to a reduce-scatter call - confirm in the .cpp.
167 comm_size_t reduce_scatter_size_;
 // Data counts indexed by leaf index; read by GetGlobalDataCountInLeaf.
169 std::vector<data_size_t> global_data_count_in_leaf_;
 // Owned LeafSplits objects for the smaller / larger leaf ("global" variants;
 // presumably holding cross-machine statistics - TODO confirm).
171 std::unique_ptr<LeafSplits> smaller_leaf_splits_global_;
173 std::unique_ptr<LeafSplits> larger_leaf_splits_global_;
 // Owned arrays of FeatureHistogram for the smaller / larger leaf ("global" variants).
175 std::unique_ptr<FeatureHistogram[]> smaller_leaf_histogram_array_global_;
177 std::unique_ptr<FeatureHistogram[]> larger_leaf_histogram_array_global_;
178
 // Storage of HistogramBinEntry values for the smaller / larger leaf; presumably
 // the backing memory of the histogram arrays above - TODO confirm.
179 std::vector<HistogramBinEntry> smaller_leaf_histogram_data_;
180 std::vector<HistogramBinEntry> larger_leaf_histogram_data_;
 // Per-entry FeatureMetainfo records; presumably one per feature - TODO confirm.
181 std::vector<FeatureMetainfo> feature_metas_;
182};
183
184// To-do: reduce the communication cost by using bitset to communicate.
185inline void SyncUpGlobalBestSplit(char* input_buffer_, char* output_buffer_, SplitInfo* smaller_best_split, SplitInfo* larger_best_split, int max_cat_threshold) {
186 // sync global best info
187 int size = SplitInfo::Size(max_cat_threshold);
188 smaller_best_split->CopyTo(input_buffer_);
189 larger_best_split->CopyTo(input_buffer_ + size);
190 Network::Allreduce(input_buffer_, size * 2, size, output_buffer_,
191 [] (const char* src, char* dst, int size, comm_size_t len) {
192 comm_size_t used_size = 0;
193 LightSplitInfo p1, p2;
194 while (used_size < len) {
195 p1.CopyFrom(src);
196 p2.CopyFrom(dst);
197 if (p1 > p2) {
198 std::memcpy(dst, src, size);
199 }
200 src += size;
201 dst += size;
202 used_size += size;
203 }
204 });
205 // copy back
206 smaller_best_split->CopyFrom(output_buffer_);
207 larger_best_split->CopyFrom(output_buffer_ + size);
208}
209
210} // namespace LightGBM
211#endif // LIGHTGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
Data parallel learning algorithm. Workers use local data to construct histograms locally,...
Definition parallel_tree_learner.h:48
The main class of the data set, which is used for training or validation.
Definition dataset.h:278
Feature parallel learning algorithm. Different machines will find the best split on different features,...
Definition parallel_tree_learner.h:22
static void Allreduce(char *input, comm_size_t input_size, int type_size, char *output, const ReduceFunction &reducer)
Perform all_reduce. If the data size is small, will perform AllreduceByAllGather, else will call ReduceSc...
Definition network.cpp:64
Tree model.
Definition tree.h:20
Voting based data parallel learning algorithm. Like data parallel, but does not aggregate histograms for a...
Definition parallel_tree_learner.h:102
void GlobalVoting(int leaf_idx, const std::vector< LightSplitInfo > &splits, std::vector< int > *out)
Perform global voting.
Definition voting_parallel_tree_learner.cpp:166
void CopyLocalHistogram(const std::vector< int > &smaller_top_features, const std::vector< int > &larger_top_features)
Copy local histogram to buffer.
Definition voting_parallel_tree_learner.cpp:198
desc and descl2 fields must be written in reStructuredText format
Definition application.h:10
int32_t data_size_t
Type of data size; it is better to use a signed type.
Definition meta.h:14
Definition config.h:27
Definition split_info.hpp:190
Used to store some information for gain split point.
Definition split_info.hpp:17