Medial Code Documentation
Loading...
Searching...
No Matches
observer.h
Go to the documentation of this file.
1
5#ifndef XGBOOST_COMMON_OBSERVER_H_
6#define XGBOOST_COMMON_OBSERVER_H_
7
#include <algorithm>
#include <cstddef>      // std::size_t
#include <cstdint>      // int32_t
#include <iomanip>      // std::setprecision, used by OBSERVER_PRINT
#include <iostream>
#include <limits>
#include <string>
#include <type_traits>  // std::enable_if, std::is_base_of, std::remove_cv
#include <vector>

#include "xgboost/parameter.h"
#include "xgboost/json.h"
#include "xgboost/base.h"
#include "xgboost/tree_model.h"
19
// Output sinks used by TrainingObserver.
//
// Regular builds write straight to std::cout with 17 significant digits and use
// real line terminators, so that dumps from two runs can be compared with `diff`.
// Strict R builds (XGBOOST_STRICT_R_MODE == 1) route output through LOG(INFO)
// and leave the explicit terminators empty.
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE != 1
#define OBSERVER_PRINT std::cout << std::setprecision(17)
#define OBSERVER_ENDL std::endl
#define OBSERVER_NEWLINE "\n"
#else  // strict R mode
#define OBSERVER_PRINT LOG(INFO)
#define OBSERVER_ENDL ""
#define OBSERVER_NEWLINE ""
#endif  // !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE != 1
29
30namespace xgboost {
/*\brief An observer for logging internal data structures.
 *
 * This class is designed to be `diff` tool friendly, which means it uses plain
 * `std::cout` for printing to avoid the time information emitted by `LOG(DEBUG)` or
 * similar facilities. Exception: it uses `LOG(INFO)` for the R package, to comply
 * with CRAN policy.
 */
#if !defined(XGBOOST_USE_DEBUG_OUTPUT)
  // Observation is compiled out: every Observe/Update call returns immediately.
  static constexpr bool kObserve = false;
#else
  // Debug-output build: observers actually print.
  static constexpr bool kObserve = true;
#endif  // !defined(XGBOOST_USE_DEBUG_OUTPUT)
44
45 public:
46 void Update(int32_t iter) const {
47 if (XGBOOST_EXPECT(!kObserve, true)) { return; }
48 OBSERVER_PRINT << "Iter: " << iter << OBSERVER_ENDL;
49 }
50 /*\brief Observe tree. */
51 void Observe(RegTree const& tree) {
52 if (XGBOOST_EXPECT(!kObserve, true)) { return; }
53 OBSERVER_PRINT << "Tree:" << OBSERVER_ENDL;
54 Json j_tree {Object()};
55 tree.SaveModel(&j_tree);
56 std::string str;
57 Json::Dump(j_tree, &str);
58 OBSERVER_PRINT << str << OBSERVER_ENDL;
59 }
60 /*\brief Observe tree. */
61 void Observe(RegTree const* p_tree) {
62 if (XGBOOST_EXPECT(!kObserve, true)) { return; }
63 auto const& tree = *p_tree;
64 this->Observe(tree);
65 }
66 template <typename T>
67 void Observe(common::Span<T> span, std::string name,
68 size_t n = std::numeric_limits<std::size_t>::max()) {
69 std::vector<T> copy(span.size());
70 std::copy(span.cbegin(), span.cend(), copy.begin());
71 this->Observe(copy, name, n);
72 }
73 /*\brief Observe data hosted by `std::vector'. */
74 template <typename T>
75 void Observe(std::vector<T> const& h_vec, std::string name,
76 size_t n = std::numeric_limits<std::size_t>::max()) const {
77 if (XGBOOST_EXPECT(!kObserve, true)) { return; }
78 OBSERVER_PRINT << "Procedure: " << name << OBSERVER_ENDL;
79
80 for (size_t i = 0; i < h_vec.size(); ++i) {
81 OBSERVER_PRINT << h_vec[i] << ", ";
82 if (i % 8 == 0 && i != 0) {
83 OBSERVER_PRINT << OBSERVER_NEWLINE;
84 }
85 if ((i + 1) == n) {
86 break;
87 }
88 }
89 OBSERVER_PRINT << OBSERVER_ENDL;
90 }
91 /*\brief Observe data hosted by `HostDeviceVector'. */
92 template <typename T>
93 void Observe(HostDeviceVector<T> const& vec, std::string name,
94 size_t n = std::numeric_limits<std::size_t>::max()) const {
95 if (XGBOOST_EXPECT(!kObserve, true)) { return; }
96 auto const& h_vec = vec.HostVector();
97 this->Observe(h_vec, name, n);
98 }
99 template <typename T>
100 void Observe(HostDeviceVector<T>* vec, std::string name,
101 size_t n = std::numeric_limits<std::size_t>::max()) const {
102 if (XGBOOST_EXPECT(!kObserve, true)) { return; }
103 this->Observe(*vec, name, n);
104 }
105
106 /*\brief Observe objects with `XGBoostParamer' type. */
107 template <typename Parameter,
108 typename std::enable_if<
109 std::is_base_of<XGBoostParameter<Parameter>, Parameter>::value>::type* = nullptr>
110 void Observe(const Parameter &p, std::string name) const {
111 if (XGBOOST_EXPECT(!kObserve, true)) { return; }
112
113 Json obj {toJson(p)};
114 OBSERVER_PRINT << "Parameter: " << name << ":\n" << obj << OBSERVER_ENDL;
115 }
116 /*\brief Observe parameters provided by users. */
117 void Observe(Args const& args) const {
118 if (XGBOOST_EXPECT(!kObserve, true)) { return; }
119
120 for (auto kv : args) {
121 OBSERVER_PRINT << kv.first << ": " << kv.second << OBSERVER_NEWLINE;
122 }
123 OBSERVER_PRINT << OBSERVER_ENDL;
124 }
125
126 /*\brief Get a global instance. */
127 static TrainingObserver& Instance() {
128 static TrainingObserver observer;
129 return observer;
130 }
131};
132} // namespace xgboost
133#endif // XGBOOST_COMMON_OBSERVER_H_
Definition host_device_vector.h:87
Definition json.h:190
Data structure representing JSON format.
Definition json.h:357
static void Dump(Json json, std::string *out, std::ios::openmode mode=std::ios::out)
Encode the JSON object.
Definition json.cc:669
define regression tree to be the most common tree model.
Definition tree_model.h:158
void SaveModel(Json *out) const override
saves the model config to a JSON object
Definition tree_model.cc:1142
Definition observer.h:38
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition span.h:424
A device-and-host vector abstraction layer.
Copyright 2015-2023 by XGBoost Contributors.
macro for using C++11 enum class as DMLC parameter
namespace of xgboost
Definition base.h:90
Copyright 2014-2023 by Contributors.