Medial Code Documentation
Loading...
Searching...
No Matches
row_block.h
Go to the documentation of this file.
1
8#ifndef DMLC_DATA_ROW_BLOCK_H_
9#define DMLC_DATA_ROW_BLOCK_H_
10
11#include <dmlc/io.h>
12#include <dmlc/logging.h>
13#include <dmlc/data.h>
14#include <cstring>
15#include <vector>
16#include <limits>
17#include <algorithm>
18
19namespace dmlc {
20namespace data {
26template<typename IndexType, typename DType = real_t>
29 std::vector<size_t> offset;
31 std::vector<DType> label;
33 std::vector<real_t> weight;
35 std::vector<uint64_t> qid;
37 std::vector<IndexType> field;
39 std::vector<IndexType> index;
41 std::vector<DType> value;
43 IndexType max_field;
45 IndexType max_index;
46 // constructor
47 RowBlockContainer(void) {
48 this->Clear();
49 }
56 inline void Save(Stream *fo) const;
62 inline bool Load(Stream *fi);
64 inline void Clear(void) {
65 offset.clear(); offset.push_back(0);
66 label.clear(); field.clear(); index.clear(); value.clear(); weight.clear(); qid.clear();
67 max_field = 0;
68 max_index = 0;
69 }
71 inline size_t Size(void) const {
72 return offset.size() - 1;
73 }
75 inline size_t MemCostBytes(void) const {
76 return offset.size() * sizeof(size_t) +
77 label.size() * sizeof(real_t) +
78 weight.size() * sizeof(real_t) +
79 qid.size() * sizeof(size_t) +
80 field.size() * sizeof(IndexType) +
81 index.size() * sizeof(IndexType) +
82 value.size() * sizeof(DType);
83 }
89 template<typename I>
90 inline void Push(Row<I, DType> row) {
91 label.push_back(row.get_label());
92 weight.push_back(row.get_weight());
93 qid.push_back(row.get_qid());
94 if (row.field != NULL) {
95 for (size_t i = 0; i < row.length; ++i) {
96 CHECK_LE(row.field[i], std::numeric_limits<IndexType>::max())
97 << "field exceed numeric bound of current type";
98 IndexType field_id = static_cast<IndexType>(row.field[i]);
99 field.push_back(field_id);
100 max_field = std::max(max_field, field_id);
101 }
102 }
103 for (size_t i = 0; i < row.length; ++i) {
104 CHECK_LE(row.index[i], std::numeric_limits<IndexType>::max())
105 << "index exceed numeric bound of current type";
106 IndexType findex = static_cast<IndexType>(row.index[i]);
107 index.push_back(findex);
108 max_index = std::max(max_index, findex);
109 }
110 if (row.value != NULL) {
111 for (size_t i = 0; i < row.length; ++i) {
112 value.push_back(row.value[i]);
113 }
114 }
115 offset.push_back(index.size());
116 }
122 template<typename I>
123 inline void Push(RowBlock<I, DType> batch) {
124 size_t size = label.size();
125 label.resize(label.size() + batch.size);
126 std::memcpy(BeginPtr(label) + size, batch.label,
127 batch.size * sizeof(DType));
128 if (batch.weight != NULL) {
129 weight.insert(weight.end(), batch.weight, batch.weight + batch.size);
130 }
131 if (batch.qid != NULL) {
132 qid.insert(qid.end(), batch.qid, batch.qid + batch.size);
133 }
134 size_t ndata = batch.offset[batch.size] - batch.offset[0];
135 if (batch.field != NULL) {
136 field.resize(field.size() + ndata);
137 IndexType *fhead = BeginPtr(field) + offset.back();
138 for (size_t i = 0; i < ndata; ++i) {
139 CHECK_LE(batch.field[i], std::numeric_limits<IndexType>::max())
140 << "field exceed numeric bound of current type";
141 IndexType field_id = static_cast<IndexType>(batch.field[i]);
142 fhead[i] = field_id;
143 max_field = std::max(max_field, field_id);
144 }
145 }
146 index.resize(index.size() + ndata);
147 IndexType *ihead = BeginPtr(index) + offset.back();
148 for (size_t i = 0; i < ndata; ++i) {
149 CHECK_LE(batch.index[i], std::numeric_limits<IndexType>::max())
150 << "index exceed numeric bound of current type";
151 IndexType findex = static_cast<IndexType>(batch.index[i]);
152 ihead[i] = findex;
153 max_index = std::max(max_index, findex);
154 }
155 if (batch.value != NULL) {
156 value.resize(value.size() + ndata);
157 std::memcpy(BeginPtr(value) + value.size() - ndata, batch.value,
158 ndata * sizeof(DType));
159 }
160 size_t shift = offset[size];
161 offset.resize(offset.size() + batch.size);
162 size_t *ohead = BeginPtr(offset) + size + 1;
163 for (size_t i = 0; i < batch.size; ++i) {
164 ohead[i] = shift + batch.offset[i + 1] - batch.offset[0];
165 }
166 }
167};
168
169template<typename IndexType, typename DType>
172 // consistency check
173 if (label.size()) {
174 CHECK_EQ(label.size() + 1, offset.size());
175 }
176 CHECK_EQ(offset.back(), index.size());
177 CHECK(offset.back() == value.size() || value.size() == 0);
179 data.size = offset.size() - 1;
180 data.offset = BeginPtr(offset);
181 data.label = BeginPtr(label);
182 data.weight = BeginPtr(weight);
183 data.qid = BeginPtr(qid);
184 data.field = BeginPtr(field);
185 data.index = BeginPtr(index);
186 data.value = BeginPtr(value);
187 return data;
188}
189template<typename IndexType, typename DType>
190inline void
192 fo->Write(offset);
193 fo->Write(label);
194 fo->Write(weight);
195 fo->Write(qid);
196 fo->Write(field);
197 fo->Write(index);
198 fo->Write(value);
199 fo->Write(&max_field, sizeof(IndexType));
200 fo->Write(&max_index, sizeof(IndexType));
201}
202template<typename IndexType, typename DType>
203inline bool
205 if (!fi->Read(&offset)) return false;
206 CHECK(fi->Read(&label)) << "Bad RowBlock format";
207 CHECK(fi->Read(&weight)) << "Bad RowBlock format";
208 CHECK(fi->Read(&qid)) << "Bad RowBlock format";
209 CHECK(fi->Read(&field)) << "Bad RowBlock format";
210 CHECK(fi->Read(&index)) << "Bad RowBlock format";
211 CHECK(fi->Read(&value)) << "Bad RowBlock format";
212 CHECK(fi->Read(&max_field, sizeof(IndexType))) << "Bad RowBlock format";
213 CHECK(fi->Read(&max_index, sizeof(IndexType))) << "Bad RowBlock format";
214 return true;
215}
216} // namespace data
217} // namespace dmlc
218#endif // DMLC_DATA_ROW_BLOCK_H_
one row of training instance
Definition data.h:74
interface of stream I/O for serialization
Definition io.h:30
defines common input data structure, and interface for handling the input data
defines serializable interface of dmlc
defines logging macros of dmlc; allows use of GLOG, falling back to an internal implementation when disabled
namespace for dmlc
Definition array_view.h:12
float real_t
this defines the floating-point type that will be used to store feature values
Definition data.h:26
T * BeginPtr(std::vector< T > &vec)
safely get the beginning address of a vector
Definition base.h:284
a block of data, containing several rows in a sparse matrix. This is useful for (streaming-style) algor...
Definition data.h:175
const DType * label
array[size] label of each instance
Definition data.h:181
size_t size
batch size
Definition data.h:177
const real_t * weight
With weight: array[size] weight of each instance, otherwise nullptr.
Definition data.h:183
const IndexType * index
feature index
Definition data.h:189
const DType * value
feature value, can be NULL, indicating all values are 1
Definition data.h:191
const uint64_t * qid
With qid: array[size] session id of each instance, otherwise nullptr.
Definition data.h:185
const IndexType * field
field id
Definition data.h:187
const size_t * offset
array[size+1], row pointer to the beginning of each row
Definition data.h:179
dynamic data structure that holds a row block of data
Definition row_block.h:27
void Save(Stream *fo) const
write the row block to a binary stream
Definition row_block.h:191
std::vector< IndexType > index
feature index
Definition row_block.h:39
IndexType max_index
maximum value of index
Definition row_block.h:45
size_t Size(void) const
size of the data (number of rows stored)
Definition row_block.h:71
std::vector< size_t > offset
array[size+1], row pointer to the beginning of each row
Definition row_block.h:29
RowBlock< IndexType, DType > GetBlock(void) const
convert to a row block
Definition row_block.h:171
std::vector< IndexType > field
field index
Definition row_block.h:37
std::vector< uint64_t > qid
array[size] session-id of each instance
Definition row_block.h:35
void Push(Row< I, DType > row)
push the row into container
Definition row_block.h:90
bool Load(Stream *fi)
load row block from a binary stream
Definition row_block.h:204
std::vector< DType > value
feature value
Definition row_block.h:41
void Push(RowBlock< I, DType > batch)
push the row block into container
Definition row_block.h:123
void Clear(void)
clear the container
Definition row_block.h:64
IndexType max_field
maximum value of field
Definition row_block.h:43
size_t MemCostBytes(void) const
Definition row_block.h:75
std::vector< real_t > weight
array[size] weight of each instance
Definition row_block.h:33
std::vector< DType > label
array[size] label of each instance
Definition row_block.h:31