Medial Code Documentation
Loading...
Searching...
No Matches
MedSparseVec.h
1// MedSparseVec
2// ------------
3//
4// Dealing with the problem of efficiently holding, inserting and retrieving an element of type T given some unsigned int unique
5// key attached to it.
6//
7// Hence the problem is: given pairs (key_i, elem_i), create a memory-efficient and fast data structure V (mimicking a vector)
8// with the main API : V.get(key) -> returns the matching element
9// The problem is how to do this efficiently when the keys are very sparse.
10//
11// If the vector V is in the range [a,b] (b-a+1 length) we would like to be able to insert : v[key]=elem, and retrieve out_elem = V[given key]
12//
13// We would like our memory to be as low as possible.
14// This implementation adds 1.5 bits for each possible key in the range (1 presence bit, plus a 32-bit running count per 64 keys).
15// So if the range (always an unsigned int) of the min,max values of keys for V is [a,b], and N = b-a+1 ,
16// and there are actually only M<N T elements, we will use memory of size: 1.5N/8 + M*sizeof(T) bytes.
17//
18// since we want to save memory this is better than a simple array only if:
19// N*sizeof(T) > 1.5N/8 + M*sizeof(T)
20// or:
21// M < N - 1.5N/(8*sizeof(T)) =(for sizeof(T)=4) N - 1.5N/32 = 0.953N ,
22//
23// so...in practice even if we keep just an int of memory as an item, we start saving memory even if the table is 0.95 full !!!
24//
25// other options are to use map<int,T> but map is using lots of memory per key, something like 32 bytes....
26// hence map will be better only for much sparser cases:
27//
28// M*(sizeof(T)+32) < 1.5N/8 + M*sizeof(T)
29// or:
30// M < 1.5N/(8*32) = 3N/512 = 0.005N !!
31//
32// Another option to compare to is to keep a vector of pairs <uint, T> (maybe even sorted);
33// this costs M*(sizeof(T)+4) and is VERY slow in lookup
34// however even here we get
35// M*(sizeof(T)+4) < 1.5N/8 + M*sizeof(T)
36// or:
37// M < 1.5N/32 = 0.047N .... so for fill ratios (M/N) between ~5% and ~95% we do better even than just KEEPING the data AS IS !!!!
38//
39// so for fill ratios (M/N) in the range of ~0.5% to ~95% this data structure saves memory.
40//
41// Another drawback of this data structure:
42// - in order to get a fast insert time, it is best to:
43// (1) set the range [a,b] in advance.
44// (2) insert the elements with the keys SORTED, such that if we enter V[i] after V[j] then i > j (or i was already inserted).
45// Failing to do so makes insert() return an error (-2).
46//
47// so the actual situation to start with when using this package is :
48// (1) have a vector of size M of keys K[i] , which are unsigned int and sorted.
49// (2) have a vector of size M of elements E[i] of type T, such that the key for E[i] is K[i].
50//
51// OR:
52// (1) make sure that whenever a new pair {k,e} is inserted, k is >= all previously inserted keys, or k was already inserted (in which case
53// ... its value is replaced with the new e).
54//
55//
56
57#ifndef __MED__SPARSE__VEC__
58#define __MED__SPARSE__VEC__
59
#include <vector>
// #include <immintrin.h>
// #include <nmmintrin.h>
#include <cstring>
#include <cstdio>	// fprintf/stderr, used by deserialize()

// MED_POPCOUNT(x): number of set bits in the 64-bit unsigned value x.
// A compiler intrinsic is used where one is known to exist; any other
// toolchain (e.g. MSVC targeting 32-bit x86 or ARM64) gets the portable
// fallback below.
#if defined(__CUDA_ARCH__)
	#define MED_POPCOUNT(x) __popcll(x)
#elif defined(__GNUC__) || defined(__clang__)
	#define MED_POPCOUNT(x) __builtin_popcountll(x)
#elif defined(_MSC_VER) && defined(_M_X64)
	#include <intrin.h>		// declares __popcnt64
	#define MED_POPCOUNT(x) __popcnt64(x)
#else
	// Portable popcount: each iteration clears the lowest set bit.
	inline int med_sparse_vec_popcount64(unsigned long long x)
	{
		int n = 0;
		while (x != 0) {
			x &= x - 1;
			++n;
		}
		return n;
	}
	#define MED_POPCOUNT(x) med_sparse_vec_popcount64(x)
#endif

// NOTE(review): a using-directive in a header leaks into every includer; it is
// kept because code that includes this header may depend on it.
using namespace std;

// Written by serialize() right after the length field and checked by deserialize().
#define MED_SPARSE_VEC_MAGIC_NUM 0x0102030405060708ULL

//------------------------------------------------------------------------------
// MedSparseVec<T> : memory-compact map from unsigned int keys to T values.
//
// Storage (key offset o = key - min_val):
//   is_in_bit : one presence bit per offset; offset o is bit (63 - o%64) of
//               word o/64, so smaller offsets occupy the higher bits.
//   counts    : counts[w] = number of stored elements in words 0..w-1
//               (one entry longer than is_in_bit).
//   data      : data[0] is the default slot returned for missing keys; the
//               k-th present key (1-based, ascending key order) maps to data[k].
//
// Keys must be inserted in non-decreasing order (see insert()).
//------------------------------------------------------------------------------
template <class T> class MedSparseVec {

	public:

	unsigned int min_val;	// lowest key allowed; offsets are computed relative to it
	unsigned int max_val;	// highest key allowed, enforced only when max_set != 0
	int max_set;			// non-zero once an upper bound was given
	unsigned int max_key;	// largest key inserted so far (0 while empty)
	T def_val;				// value returned for keys that were never inserted

	vector<unsigned int> counts;
	vector<unsigned long long> is_in_bit;
	vector<T> data;

	// Lower key bound. Offsets are relative to min_val, so changing it after
	// insert() changes which stored element a key maps to.
	void set_min(int _min) { min_val = _min; }
	// Upper key bound; also enables the bound check (max_set = 1).
	void set_max(int _max) { max_val = _max; max_set = 1; }

	// Constructors: optional lower / upper key bounds (stored as unsigned int).
	MedSparseVec() { min_val = 0; max_val = 0; max_set = 0; max_key = 0; init(); }
	MedSparseVec(int _min) { min_val = _min; max_val = 0; max_set = 0; max_key = 0; init(); }
	MedSparseVec(int _min, int _max) { min_val = _min; max_set = 1; max_val = _max; max_key = 0; init(); }

	// Value returned for missing keys (kept both in def_val and in data[0]).
	void set_def(const T val) { def_val = val; data[0] = def_val; }

	// Read access: the stored element, or the default value for a missing key.
	inline T operator[] (const unsigned int key) const { return (*get(key)); }

	// Read/write access. CAUTION: for a missing key this is a reference to the
	// shared default slot data[0]; assigning through it changes the default for
	// every missing key rather than inserting. Use insert() to add keys.
	inline T &operator[] (const unsigned int key) { return (*get(key)); }

	// Reserve capacity in data (which also holds the default slot).
	void reserve(unsigned int size) { data.reserve(size); }

	//------------------------------------------------------
	// init(): reset storage to an empty vector holding only the default slot
	// (default value 0). Keeps min_val/max_val/max_set; clear() resets those too.
	//------------------------------------------------------
	void init() {
		// assign (not resize) so that calling init() on a used vector really empties it
		counts.assign(2, 0);
		is_in_bit.assign(1, 0ULL);
		data.clear();
		def_val = 0;
		data.push_back(def_val);
		max_key = 0;
	}

	//------------------------------------------------------
	// clear a used sparse vec, going back to all defaults
	//------------------------------------------------------
	void clear() {
		counts.assign(2, 0);
		is_in_bit.assign(1, 0ULL);
		data.clear();
		def_val = 0;
		data.push_back(def_val);
		min_val = 0; max_val = 0; max_set = 0; max_key = 0;
	}

	//------------------------------------------------------
	// get_ind(key): index of key's element in data, or 0 (the default slot)
	// when the key is out of range or was never inserted.
	//------------------------------------------------------
	inline unsigned int get_ind(const unsigned int key) const
	{
		if (key < min_val || key > max_key || (max_set && key > max_val))
			return 0; // not in range

		unsigned int mkey = key - min_val;
		unsigned int i_count = mkey >> 6;
		// Guard against reading past the bitmap; possible only if min_val was
		// lowered after inserts or the state came from a corrupt blob.
		if (i_count >= is_in_bit.size())
			return 0;

		unsigned int i_bit = 63 - (mkey & 0x3f);
		unsigned long long kbits = is_in_bit[i_count] >> i_bit;
		if ((kbits & 0x1) == 0)
			return 0; // no value inserted

		// kbits keeps this key's bit and the bits of all smaller offsets in the
		// word, so its popcount is the key's 1-based rank within the word.
		return counts[i_count] + (unsigned int)MED_POPCOUNT(kbits);
	}

	//------------------------------------------------------
	// insert(key, elem)
	// Returns:
	//   0  : inserted as new max key, or existing key's value replaced
	//  -1  : key < min_val, or key > max_val while max_set
	//  -2  : key < max_key and not present (keys must arrive in sorted order)
	//------------------------------------------------------
	int insert(const unsigned int key, const T elem)
	{
		if (key < min_val || (max_set && key > max_val)) return -1;

		unsigned int mkey = key - min_val;
		unsigned int i_count = mkey >> 6;
		unsigned int i_bit = 63 - (mkey & 0x3f);
		unsigned long long i_mask = (1ULL << i_bit);

		if (key <= max_key) {
			if (is_in_bit[i_count] & i_mask) {
				// already present: overwrite in place
				unsigned int pos = counts[i_count] + (unsigned int)MED_POPCOUNT(is_in_bit[i_count] >> i_bit);
				data[pos] = elem;
				return 0;
			}
			if (key < max_key)
				return -2; // elements were not inserted in the correct sorted order
			// key == max_key with its bit clear: the vector is still empty
			// (max_key starts at 0), so fall through and append.
		}

		// New max key: extend the bitmap/counts up to key's word and append.
		unsigned int last_cnt = counts.empty() ? 0 : counts.back();
		counts.resize(i_count + 2, last_cnt);	// skipped words hold the running total
		is_in_bit.resize(i_count + 1, 0);
		is_in_bit[i_count] |= i_mask;
		counts[i_count + 1] = counts[i_count] + (unsigned int)MED_POPCOUNT(is_in_bit[i_count]);
		data.push_back(elem);
		max_key = key;
		return 0;
	}

	//---------------------------------------------------------------
	// get(key): pointer to key's element. Never NULL: a missing key yields a
	// pointer to the default slot data[0].
	//---------------------------------------------------------------
	inline T *get(unsigned int key) {
		return &data[get_ind(key)];
	}
	inline const T *get(unsigned int key) const {
		return &data[get_ind(key)];
	}

	//---------------------------------------------------------------
	// get_all_keys: fills keys with every key get_ind() reports as present,
	// in ascending order. Walks the presence bitmap, so cost is proportional
	// to the bitmap size, and it is safe when max_key == UINT_MAX.
	// Returns 0.
	//---------------------------------------------------------------
	int get_all_keys(vector<unsigned int> &keys) const {

		keys.clear();
		if (!data.empty())
			keys.reserve(data.size() - 1);

		for (size_t w = 0; w < is_in_bit.size(); w++) {
			unsigned long long bits = is_in_bit[w];
			// offset o within the word lives at bit (63 - o): scan top bit downwards
			for (unsigned int o = 0; bits != 0 && o < 64; o++) {
				unsigned long long mask = 1ULL << (63 - o);
				if ((bits & mask) == 0)
					continue;
				bits &= ~mask;
				unsigned int key = min_val + (unsigned int)((w << 6) + o);
				if (key > max_key || (max_set && key > max_val))
					return 0; // keys come out ascending; nothing further is reachable
				keys.push_back(key);
			}
		}

		return 0;
	}

	//---------------------------------------------------------------
	// get_all_intersected_keys:
	// input:
	//  in_keys - keys to probe (each converted to unsigned int); duplicates
	//            are reported as many times as they appear
	// outputs:
	//  keys - the keys of in_keys that are present, in in_keys order
	//  inds - the data index of each of those keys (same size as keys)
	// Returns 0.
	//---------------------------------------------------------------
	int get_all_intersected_keys(const vector<int> &in_keys, vector<int> &keys, vector<int> &inds) const {

		keys.resize(in_keys.size());
		inds.resize(in_keys.size());
		size_t i_size = 0;

		for (size_t i = 0; i < in_keys.size(); i++) {
			unsigned int ind = get_ind((unsigned int)in_keys[i]);
			if (ind > 0) {
				keys[i_size] = in_keys[i];
				inds[i_size] = (int)ind;
				i_size++;
			}
		}

		keys.resize(i_size);
		inds.resize(i_size);
		return 0;
	}

	//---------------------------------------------------------------
	// Serializations
	// Layout (native endianness / type sizes):
	//  u64 total_len | u64 magic | u32 min_val | u32 max_val | i32 max_set |
	//  u32 max_key | T def_val | u32 n, u32 counts[n] | u32 n, u64 is_in_bit[n] |
	//  u32 n, T data[n]
	//---------------------------------------------------------------
	// Number of bytes serialize() will write for the current state.
	size_t get_size() const {
		size_t size = 0;

		size += sizeof(unsigned long long); // len of serialization
		size += sizeof(unsigned long long); // magic number recognizer
		size += sizeof(unsigned int); // min_val
		size += sizeof(unsigned int); // max_val
		size += sizeof(int); // max_set
		size += sizeof(unsigned int); // max_key
		size += sizeof(T); // def_val
		size += sizeof(unsigned int); // len counts
		size += sizeof(unsigned int) * counts.size(); // counts vector
		size += sizeof(unsigned int); // len is_in_bit
		size += sizeof(unsigned long long) * is_in_bit.size(); // is_in_bit vector
		size += sizeof(unsigned int); // len data
		size += sizeof(T) * data.size(); // data vector

		return size;
	}

	//---------------------------------------------------------------
	// serialize: writes the state into blob (which must hold at least
	// get_size() bytes) and returns the number of bytes written.
	//---------------------------------------------------------------
	size_t serialize(unsigned char *blob) const {

		unsigned char *curr = blob + sizeof(unsigned long long); // total length is written last, at offset 0

		write_raw(curr, (unsigned long long)MED_SPARSE_VEC_MAGIC_NUM);
		write_raw(curr, min_val);
		write_raw(curr, max_val);
		write_raw(curr, max_set);
		write_raw(curr, max_key);
		write_raw(curr, def_val);

		write_raw(curr, (unsigned int)counts.size());
		write_bytes(curr, counts.data(), sizeof(unsigned int) * counts.size());

		write_raw(curr, (unsigned int)is_in_bit.size());
		write_bytes(curr, is_in_bit.data(), sizeof(unsigned long long) * is_in_bit.size());

		write_raw(curr, (unsigned int)data.size());
		for (size_t i = 0; i < data.size(); i++) {
			const T v = data[i];	// element-wise copy also works for proxy containers
			write_raw(curr, v);
		}

		unsigned long long len = (unsigned long long)(curr - blob);
		memcpy(blob, &len, sizeof(len));
		return len;
	}

	//---------------------------------------------------------------
	// deserialize: loads a blob produced by serialize(). Returns the number of
	// bytes consumed, or 0 on error (bad magic number, or length mismatch).
	// On a bad magic number the object is left untouched.
	// NOTE(review): no buffer size is passed, so a truncated or corrupt blob can
	// be read past its end; only pass complete blobs from trusted sources.
	//---------------------------------------------------------------
	size_t deserialize(unsigned char *blob) {

		const unsigned char *curr = blob;

		unsigned long long serialize_len = 0;
		unsigned long long magic_num = 0;
		read_raw(curr, serialize_len);
		read_raw(curr, magic_num);
		// validate before modifying any member so a bad blob cannot leave us empty/broken
		if (magic_num != (unsigned long long)MED_SPARSE_VEC_MAGIC_NUM) {
			fprintf(stderr, "ERROR: Sparse Vec Magic Num wrong : can't deserialize() (%llx)\n", magic_num);
			return 0;
		}

		read_raw(curr, min_val);
		read_raw(curr, max_val);
		read_raw(curr, max_set);
		read_raw(curr, max_key);
		read_raw(curr, def_val);

		unsigned int len_counts = 0;
		read_raw(curr, len_counts);
		counts.assign((size_t)len_counts, 0u);
		read_bytes(curr, counts.data(), sizeof(unsigned int) * len_counts);

		unsigned int len_is_in_bit = 0;
		read_raw(curr, len_is_in_bit);
		is_in_bit.assign((size_t)len_is_in_bit, 0ULL);
		read_bytes(curr, is_in_bit.data(), sizeof(unsigned long long) * len_is_in_bit);

		unsigned int len_data = 0;
		read_raw(curr, len_data);
		data.assign((size_t)len_data, def_val);
		for (unsigned int i = 0; i < len_data; i++) {
			T v = def_val;
			read_raw(curr, v);
			data[i] = v;
		}

		size_t len = (size_t)(curr - blob);
		if (len != serialize_len) {
			fprintf(stderr, "ERROR: Sparse Vec serialize len not matching decalred one: %zu != %llu\n", len, serialize_len);
			return 0;
		}
		return len;
	}

	private:

	// Byte-wise (de)serialization helpers. memcpy avoids the misaligned and
	// strict-aliasing-violating accesses that pointer casts into the blob would
	// cause (e.g. 8-byte is_in_bit words can land on 4-byte-aligned offsets).
	template <class V> static void write_raw(unsigned char *&curr, const V &v) {
		memcpy(curr, &v, sizeof(V));
		curr += sizeof(V);
	}
	template <class V> static void read_raw(const unsigned char *&curr, V &v) {
		memcpy(&v, curr, sizeof(V));
		curr += sizeof(V);
	}
	// n may be 0, in which case src/dst may be null (empty vector data()).
	static void write_bytes(unsigned char *&curr, const void *src, size_t n) {
		if (n > 0) memcpy(curr, src, n);
		curr += n;
	}
	static void read_bytes(const unsigned char *&curr, void *dst, size_t n) {
		if (n > 0) memcpy(dst, curr, n);
		curr += n;
	}

};
384
385#endif
Definition MedSparseVec.h:82
Definition BFloat16.h:88