Medial Code Documentation
Loading...
Searching...
No Matches
gblinear_model.h
1
4#pragma once
5#include <dmlc/io.h>
6#include <dmlc/parameter.h>
7#include <xgboost/learner.h>
8
9#include <vector>
10#include <string>
11#include <cstring>
12
13#include "xgboost/base.h"
14#include "xgboost/feature_map.h"
15#include "xgboost/model.h"
16#include "xgboost/json.h"
17#include "xgboost/parameter.h"
18
19namespace xgboost {
20class Json;
21namespace gbm {
22// Deprecated in 1.0.0. model parameter. Only staying here for compatible binary model IO.
23struct DeprecatedGBLinearModelParam : public dmlc::Parameter<DeprecatedGBLinearModelParam> {
24 // number of feature dimension
25 uint32_t deprecated_num_feature;
26 // deprecated. use learner_model_param_->num_output_group.
27 int32_t deprecated_num_output_group;
28 // reserved field
29 int32_t reserved[32];
30 // constructor
32 static_assert(sizeof(*this) == sizeof(int32_t) * 34,
33 "Model parameter size can not be changed.");
34 std::memset(this, 0, sizeof(DeprecatedGBLinearModelParam));
35 }
36
37 DMLC_DECLARE_PARAMETER(DeprecatedGBLinearModelParam) {
38 DMLC_DECLARE_FIELD(deprecated_num_feature);
39 DMLC_DECLARE_FIELD(deprecated_num_output_group);
40 }
41};
42
43// model for linear booster
44class GBLinearModel : public Model {
45 private:
46 // Deprecated in 1.0.0
48
49 public:
50 int32_t num_boosted_rounds{0};
51 LearnerModelParam const* learner_model_param;
52
53 public:
54 explicit GBLinearModel(LearnerModelParam const *learner_model_param)
55 : learner_model_param{learner_model_param} {}
56 void Configure(Args const &) { }
57
58 // weight for each of feature, bias is the last one
59 std::vector<bst_float> weight;
60 // initialize the model parameter
61 inline void LazyInitModel() {
62 if (!weight.empty()) {
63 return;
64 }
65 // bias is the last weight
66 weight.resize((learner_model_param->num_feature + 1) *
67 learner_model_param->num_output_group);
68 std::fill(weight.begin(), weight.end(), 0.0f);
69 }
70
71 void SaveModel(Json *p_out) const override;
72 void LoadModel(Json const &in) override;
73
74 // save the model to file
75 void Save(dmlc::Stream *fo) const {
76 fo->Write(&param_, sizeof(param_));
77 fo->Write(weight);
78 }
79 // load model from file
80 void Load(dmlc::Stream *fi) {
81 CHECK_EQ(fi->Read(&param_, sizeof(param_)), sizeof(param_));
82 fi->Read(&weight);
83 }
84
85 // model bias
86 inline bst_float *Bias() {
87 return &weight[learner_model_param->num_feature *
88 learner_model_param->num_output_group];
89 }
90 inline const bst_float *Bias() const {
91 return &weight[learner_model_param->num_feature *
92 learner_model_param->num_output_group];
93 }
94 // get i-th weight
95 inline bst_float *operator[](size_t i) {
96 return &weight[i * learner_model_param->num_output_group];
97 }
98 inline const bst_float *operator[](size_t i) const {
99 return &weight[i * learner_model_param->num_output_group];
100 }
101
102 std::vector<std::string> DumpModel(const FeatureMap &, bool,
103 std::string format) const {
104 const int ngroup = learner_model_param->num_output_group;
105 const unsigned nfeature = learner_model_param->num_feature;
106
107 std::stringstream fo("");
108 if (format == "json") {
109 fo << " { \"bias\": [" << std::endl;
110 for (int gid = 0; gid < ngroup; ++gid) {
111 if (gid != 0) {
112 fo << "," << std::endl;
113 }
114 fo << " " << this->Bias()[gid];
115 }
116 fo << std::endl
117 << " ]," << std::endl
118 << " \"weight\": [" << std::endl;
119 for (unsigned i = 0; i < nfeature; ++i) {
120 for (int gid = 0; gid < ngroup; ++gid) {
121 if (i != 0 || gid != 0) {
122 fo << "," << std::endl;
123 }
124 fo << " " << (*this)[i][gid];
125 }
126 }
127 fo << std::endl << " ]" << std::endl << " }";
128 } else {
129 fo << "bias:\n";
130 for (int gid = 0; gid < ngroup; ++gid) {
131 fo << this->Bias()[gid] << std::endl;
132 }
133 fo << "weight:\n";
134 for (unsigned i = 0; i < nfeature; ++i) {
135 for (int gid = 0; gid < ngroup; ++gid) {
136 fo << (*this)[i][gid] << std::endl;
137 }
138 }
139 }
140 std::vector<std::string> v;
141 v.push_back(fo.str());
142 return v;
143 }
144};
145
146} // namespace gbm
147} // namespace xgboost
interface of stream I/O for serialization
Definition io.h:30
Feature map data structure to help text model dump. TODO(tqchen): consider making it even more lightweight.
Definition feature_map.h:22
Data structure representing JSON format.
Definition json.h:357
Definition gblinear_model.h:44
void SaveModel(Json *p_out) const override
saves the model config to a JSON object
Definition gblinear_model.cc:13
void LoadModel(Json const &in) override
load the model from a JSON object
Definition gblinear_model.cc:23
defines serializable interface of dmlc
Provide lightweight util to do parameter setup and checking.
Feature map data structure to help visualization and model dump.
Copyright 2015-2023 by XGBoost Contributors.
macro for using C++11 enum class as DMLC parameter
Copyright 2015-2023 by XGBoost Contributors.
Defines the abstract interface for different components in XGBoost.
namespace of xgboost
Definition base.h:90
float bst_float
float type, used for storing statistics
Definition base.h:97
Basic model parameters, used to describe the booster.
Definition learner.h:291
std::uint32_t num_output_group
The number of classes or targets.
Definition learner.h:307
bst_feature_t num_feature
The number of features.
Definition learner.h:303
Definition model.h:17
Definition gblinear_model.h:23