GENIE
match_device_utils.h
Go to the documentation of this file.
1 
11 #ifndef GENIE_MATCH_DEVICE_UTILS_H
12 #define GENIE_MATCH_DEVICE_UTILS_H
13 
14 #include <genie/query/query.h>
15 #include "match_common.h"
16 
17 namespace genie
18 {
19 namespace matching
20 {
21 
// Layout of a packed T_HASHTABLE entry (64 bits, most significant first):
//   age       : bits [63:60]  (the 4 bits left above the two fields below)
//   attach id : bits [59:28]  (ATTACH_ID_TYPE_BITS wide)
//   key / pos : bits [27: 0]  (KEY_TYPE_BITS wide)

// Number of probe ages; ages used for hashing run from 0 to MAX_AGE - 1.
const T_AGE MAX_AGE = 16u;

// Width and mask of the key/position field.
const uint32_t KEY_TYPE_BITS = 28u;
const uint32_t KEY_TYPE_MASK = u32((u64(1) << KEY_TYPE_BITS) - 1u);

// Width and mask of the attach-id field (holds a float bit pattern in the
// adaptive-threshold kernels).
const uint32_t ATTACH_ID_TYPE_BITS = 32u;
const uint32_t ATTACH_ID_TYPE_MASK = u32((u64(1) << ATTACH_ID_TYPE_BITS) - 1u);

// Age given to a freshly inserted key; age 0 marks an empty slot.
const uint32_t KEY_TYPE_INIT_AGE = 1u;
const uint32_t KEY_TYPE_NULL_AGE = 0u;
29 
// Per-age offsets added to the key by hash(), so that every probe age maps
// a key to a different slot. h_offsets is the host-side copy; presumably
// host code (not shown here) uploads it into d_offsets, e.g. with
// cudaMemcpyToSymbol -- TODO confirm.
static const uint32_t h_offsets[] = {
    0u,       3949349u, 8984219u, 9805709u,
    7732727u, 1046459u, 9883879u, 4889399u,
    2914183u, 3503623u, 1734349u, 8860463u,
    1326319u, 1613597u, 8604269u, 9647369u
};

// Device-side offset table, one entry per probe age (MAX_AGE == 16 entries).
static __device__ __constant__ u32 d_offsets[MAX_AGE];
35 
// Returns the key/position field (the low KEY_TYPE_BITS bits) of a packed
// hash-table entry.
__forceinline__ __host__ __device__ T_KEY get_key_pos(T_HASHTABLE key)
{
    const T_HASHTABLE pos_bits = key & KEY_TYPE_MASK;
    return static_cast<T_KEY>(pos_bits);
}
40 
// Returns the age field, i.e. everything above the attach-id and key fields
// of a packed hash-table entry.
__forceinline__ __host__ __device__ T_AGE get_key_age(T_HASHTABLE key)
{
    const uint32_t age_shift = ATTACH_ID_TYPE_BITS + KEY_TYPE_BITS;
    return static_cast<T_AGE>(key >> age_shift);
}
45 
// Returns the attach-id field of a packed entry. In the adaptive-threshold
// kernels this field holds the bit pattern of a float count for the item.
__host__ __forceinline__ __device__
u32 get_key_attach_id(T_HASHTABLE key)
{
    const T_HASHTABLE without_pos = key >> KEY_TYPE_BITS;
    return static_cast<u32>(without_pos & ATTACH_ID_TYPE_MASK);
}
// Builds an entry that carries only the key/position field (attach id and
// age are zero). Bits of p above KEY_TYPE_BITS are discarded.
__host__ __forceinline__ __device__
T_HASHTABLE pack_key_pos(T_KEY p)
{
    const T_KEY masked_pos = p & KEY_TYPE_MASK;
    return masked_pos;
}
// Packs key/position p, attach id i and age a into one 64-bit entry.
// p is truncated to KEY_TYPE_BITS and i to ATTACH_ID_TYPE_BITS. Bits of a
// that do not fit above bit 60 are shifted out of the 64-bit word.
__host__ __forceinline__ __device__
T_HASHTABLE pack_key_pos_and_attach_id_and_age(T_KEY p, u32 i, T_AGE a)
{
    const T_HASHTABLE age_field =
        u64(a) << (ATTACH_ID_TYPE_BITS + KEY_TYPE_BITS);
    const T_HASHTABLE attach_field =
        (u64(i) & ATTACH_ID_TYPE_MASK) << KEY_TYPE_BITS;
    const T_HASHTABLE pos_field = u64(p & KEY_TYPE_MASK);
    // The three fields occupy disjoint bit ranges, so OR-ing them is the
    // same as adding them.
    return age_field | attach_field | pos_field;
}
64 
// Slot index of `key` at probe `age`: the per-age offset d_offsets[age] is
// added to the key and the sum is reduced modulo hash_table_size.
// d_offsets holds 16 entries, so callers must keep age below that bound;
// it is not checked here.
__forceinline__ __device__ u32 hash(T_KEY key, T_AGE age,
        int hash_table_size)
{
    const u32 shifted_key = d_offsets[age] + key;
    return shifted_key % hash_table_size;
}
70 
// Writes the 32-bit binary representation of `data` into `b`, most
// significant bit first, followed by a NUL terminator. `b` must have room
// for 33 characters.
__forceinline__ __device__ __host__
void print_binary(char * b, u32 data)
{
    for (int pos = 0; pos < 32; ++pos)
    {
        const u32 bit = (data >> (31 - pos)) & 1u;
        b[pos] = bit ? '1' : '0';
    }
    b[32] = '\0';
}
78 
// Reads the unsigned field of width `bits` that starts at bit `offset` of
// `data`. Valid for 1 <= bits <= 32 and offset + bits <= 32.
//
// bits == 32 (one counter per word, as allowed by bitmap_kernel_AT's
// 32 / bits packing) is handled explicitly: `1u << 32` is undefined
// behaviour, and it can produce different masks on host and device.
__forceinline__ __device__ __host__ u32
get_count(u32 data, int offset, int bits)
{
    const u32 field_mask = (bits >= 32) ? 0xFFFFFFFFu : ((1u << bits) - 1u);
    return (data >> offset) & field_mask;
}
84 
// Returns `data` with the field of width `bits` at bit `offset` replaced by
// `count`. Valid for 1 <= bits <= 32 and offset + bits <= 32.
//
// `count` is truncated to the field width, so a value that does not fit
// (e.g. an incremented counter that reached 1 << bits) cannot spill into
// the neighbouring field packed above it. bits == 32 is handled explicitly
// because `1u << 32` is undefined behaviour.
__forceinline__ __device__ __host__ u32
pack_count(u32 data, int offset, int bits, u32 count)
{
    const u32 field_mask = (bits >= 32) ? 0xFFFFFFFFu : ((1u << bits) - 1u);
    const u32 cleared = data & ~(field_mask << offset);
    return cleared | ((count & field_mask) << offset);
}
// Looks up `id` in `htable`. If it is present, adds q.weight to the float
// value stored in the entry's attach-id field.
//
// Probe sequence: slot hash(id, age, hash_table_size) for every age from 0
// to MAX_AGE - 1. The loop stops at MAX_AGE - 1 because hash() indexes the
// 16-entry d_offsets table by age; probing age MAX_AGE would read past its
// end.
//
// An entry matches when its key field equals `id` and its age is not
// KEY_TYPE_NULL_AGE. The age field is 4 bits wide, so the `< MAX_AGE` part
// of the test never rejects; it is kept for clarity. The update is
// published with atomicCAS. If another thread modified the slot in between,
// the same slot is re-read and the update retried.
//
// Parameters:
//   id              key to look up (compared with get_key_pos()).
//   htable          table of packed entries.
//   hash_table_size number of slots in htable.
//   q               query dimension; only q.weight is read.
//   key_found       out: true if an entry for `id` was found and updated,
//                   false otherwise.
__forceinline__ __device__
void access_kernel(u32 id, T_HASHTABLE* htable, int hash_table_size,
        genie::query::Query::dim& q, bool * key_found)
{
    for (T_AGE age = KEY_TYPE_NULL_AGE; age < MAX_AGE; ++age)
    {
        const u32 location = hash(id, age, hash_table_size);
        while (true)
        {
            const T_HASHTABLE out_key = htable[location];
            if (get_key_pos(out_key) != id
                    || get_key_age(out_key) == KEY_TYPE_NULL_AGE
                    || get_key_age(out_key) >= MAX_AGE)
            {
                break; // not our entry: try the next age
            }

            const float updated_value =
                __uint_as_float(get_key_attach_id(out_key)) + q.weight;
            const T_HASHTABLE new_key = pack_key_pos_and_attach_id_and_age(
                    get_key_pos(out_key),
                    __float_as_uint(updated_value),
                    get_key_age(out_key));
            if (atomicCAS(&htable[location], out_key, new_key) == out_key)
            {
                *key_found = true;
                return;
            }
            // Lost a race on this slot: re-read it and retry.
        }
    }
    // Entry not found.
    *key_found = false;
}
158 
//for AT: for adaptiveThreshold
// Adaptive-threshold lookup. If `id` is already in `htable`, tries to store
// `count` (as a float) in the attach-id field of its entry.
//
// Probe sequence: slot hash(id, age, hash_table_size) for every age from 0
// to MAX_AGE - 1. The loop stops there because hash() indexes the 16-entry
// d_offsets table by age. An entry matches when its key field equals `id`
// and its age is not KEY_TYPE_NULL_AGE.
//
// On a match:
//   - count < stored value:   entry left unchanged; *key_found = true,
//                             *pass_threshold = true.
//   - count < *my_threshold:  entry left unchanged; *key_found = true,
//                             *pass_threshold = false.
//   - otherwise the new value is published with atomicCAS. On success
//     *key_found = true and *pass_threshold = true. On a lost race the same
//     slot is re-read and re-evaluated.
// If nothing matches: *key_found = false, *pass_threshold = false.
// q is not read by this function.
__device__ __forceinline__ void
access_kernel_AT(u32 id, T_HASHTABLE* htable, int hash_table_size,
        genie::query::Query::dim& q, u32 count, bool * key_found, u32* my_threshold,
        bool * pass_threshold // if the count is smaller than my_threshold, do not insert
        )
{
    const float new_value = static_cast<float>(count);

    for (T_AGE age = KEY_TYPE_NULL_AGE; age < MAX_AGE; ++age)
    {
        const u32 location = hash(id, age, hash_table_size);
        while (true)
        {
            const T_HASHTABLE out_key = htable[location];
            if (get_key_pos(out_key) != id
                    || get_key_age(out_key) == KEY_TYPE_NULL_AGE
                    || get_key_age(out_key) >= MAX_AGE)
            {
                break; // not our entry: try the next age
            }

            const float stored_value =
                __uint_as_float(get_key_attach_id(out_key));
            if (new_value < stored_value)
            {
                // Key found with a larger value already; do not update.
                *pass_threshold = true;
                *key_found = true;
                return;
            }
            if (new_value < *my_threshold)
            {
                // Threshold has moved past this value; do not update.
                *pass_threshold = false;
                *key_found = true;
                return;
            }

            const T_HASHTABLE new_key = pack_key_pos_and_attach_id_and_age(
                    get_key_pos(out_key),
                    __float_as_uint(new_value), //for improve: update here for weighted distance
                    get_key_age(out_key));
            if (atomicCAS(&htable[location], out_key, new_key) == out_key)
            {
                *pass_threshold = true;
                *key_found = true;
                return;
            }
            // Lost a race on this slot: re-read it and retry.
        }
    }

    *key_found = false;
    // Key not found, so my_threshold does not need updating.
    *pass_threshold = false;
}
257 
//for AT: for countHeap (with adaptiveThreshold)
// Inserts `id` with value `count` into `htable`, or raises the value of an
// existing entry for `id`. Uses age-based displacement.
//
// The key in hand starts as (pos = id, attach id = float bits of count,
// age = KEY_TYPE_INIT_AGE) and is probed at hash(pos, age) for
// age < MAX_AGE. At each slot:
//   - same key already resident (non-null age): raise its value in place
//     via atomicCAS, unless the resident value is larger or the value fell
//     below *my_threshold.
//   - resident has a smaller age, or is non-empty with a value below
//     *my_threshold: the key in hand replaces it via atomicCAS.
//     - An empty slot finishes the insertion and increments *my_noiih.
//     - A live resident with a value below *my_threshold is dropped.
//     - Any other live resident becomes the key in hand and is re-inserted
//       from its own age.
//   - otherwise: bump the age of the key in hand and probe again.
//
// Out-parameters:
//   *overflow       set to true when *my_noiih exceeds the table size, or
//                   when the key in hand runs out of ages. Never cleared here.
//   *pass_threshold set on every return path except the two overflow
//                   returns that are driven by *my_noiih.
//   *my_noiih       atomically incremented whenever an empty slot is filled.
// q is not read (a weight-based value was considered; see the note below).
__device__ __forceinline__ void
hash_kernel_AT(
        u32 id,
        T_HASHTABLE* htable, int hash_table_size, genie::query::Query::dim& q, u32 count,
        u32* my_threshold, //for AT: for adaptiveThreshold, if the count is smaller than my_threshold, this item is also expired in the hashTable
        u32 * my_noiih, bool * overflow, bool* pass_threshold)
{
    // *my_noiih is compared with the table size as unsigned, exactly as the
    // implicit int -> u32 conversion did; made explicit here.
    const u32 table_size = static_cast<u32>(hash_table_size);

    T_AGE age = KEY_TYPE_NULL_AGE;
    const float count_value = static_cast<float>(count);
    // The attach-id field stores the float bit pattern of count.
    // (Disabled alternative for weighted distance: __float_as_uint(q.weight).)
    T_HASHTABLE key = pack_key_pos_and_attach_id_and_age(id,
            __float_as_uint(count_value), KEY_TYPE_INIT_AGE);

    // Loop until MAX_AGE.
    while (age < MAX_AGE)
    {
        const float key_value = __uint_as_float(get_key_attach_id(key));
        if (key_value < *my_threshold)
        {
            // The key in hand has expired. If it is our own id, it did not
            // pass the threshold. Otherwise our id was already placed and
            // this is an evicted entry that is simply dropped.
            *pass_threshold = (get_key_pos(key) != id);
            return;
        }

        const u32 location = hash(get_key_pos(key), age, hash_table_size);
        while (true)
        {
            if (*my_noiih > table_size)
            {
                *overflow = true;
                return;
            }

            const T_HASHTABLE peek_key = htable[location];
            const float peek_key_value =
                __uint_as_float(get_key_attach_id(peek_key));

            // Same key already resident. It may have been inserted by
            // another thread since we last looked.
            if (get_key_pos(peek_key) == get_key_pos(key)
                    && get_key_age(peek_key) != 0u)
            {
                if (key_value < peek_key_value)
                {
                    // Resident value is larger: no update needed.
                    *pass_threshold = true;
                    return;
                }

                const T_HASHTABLE new_key = pack_key_pos_and_attach_id_and_age(
                        get_key_pos(peek_key),
                        __float_as_uint(key_value), //for improve: update here for weighted distance
                        get_key_age(peek_key));

                if (key_value < *my_threshold)
                {
                    *pass_threshold = (get_key_pos(key) != id);
                    return;
                }
                if (atomicCAS(&htable[location], peek_key, new_key) == peek_key)
                {
                    *pass_threshold = true;
                    return;
                }
                continue; // lost a race on this slot: re-read it
            }

            // The slot can be taken if its resident is younger than the key
            // in hand (this includes empty slots, age 0), or if it is
            // non-empty and its value is below the threshold.
            const bool can_take_slot =
                get_key_age(peek_key) < get_key_age(key)
                || (get_key_age(peek_key) != KEY_TYPE_NULL_AGE
                    && peek_key_value < *my_threshold);

            if (!can_take_slot)
            {
                // Retry the same key one age further.
                age++;
                key = pack_key_pos_and_attach_id_and_age(get_key_pos(key),
                        get_key_attach_id(key), age);
                break;
            }

            if (key_value < *my_threshold)
            {
                *pass_threshold = (get_key_pos(key) != id);
                return;
            }

            const T_HASHTABLE evicted_key =
                atomicCAS(&htable[location], peek_key, key);
            if (evicted_key != peek_key)
                continue; // slot changed underneath us: re-read it

            if (get_key_age(evicted_key) > 0u)
            {
                // Displaced a live entry.
                if (peek_key_value < *my_threshold)
                {
                    // It had expired, so drop it instead of re-inserting.
                    *pass_threshold = true;
                    return;
                }
                // Carry the displaced entry forward and re-insert it.
                key = evicted_key;
                age = get_key_age(evicted_key);
                break;
            }

            // Filled an empty slot: count it toward table occupancy.
            if (*my_noiih >= table_size)
            {
                *overflow = true;
                atomicAdd(my_noiih, 1u);
                return;
            }
            atomicAdd(my_noiih, 1u);
            *pass_threshold = true;
            return; // insertion into an empty slot is complete
        }
    }

    // Ran out of probe ages while still carrying a key.
    *overflow = true;
    *pass_threshold = true;
}
420 
//for AT: for adaptiveThreshold, this is function for bitmap
// Atomically increments the counter of width `bits` that belongs to
// `access_id` in `bitmap`, and returns the incremented count.
//
// Counters are packed 32 / bits per u32 word. The counter for access_id
// lives in word access_id / (32 / bits), at bit offset
// (access_id % (32 / bits)) * bits. `bits` is assumed to divide 32 --
// TODO confirm with callers.
//
// *key_eligible is set to (incremented count >= my_threshold). The
// comparison is unsigned, as in the implicit conversion it replaces.
//
// NOTE(review): the increment is not saturated. If a counter already holds
// (1 << bits) - 1, the returned count exceeds what the field can store, and
// what gets written back is decided by pack_count().
__device__ __forceinline__ u32
bitmap_kernel_AT(u32 access_id, u32 * bitmap, int bits, int my_threshold,
        bool * key_eligible)
{
    // The word index and bit offset do not change across CAS retries.
    const u32 counters_per_word = 32 / bits;
    u32 * const word = &bitmap[access_id / counters_per_word];
    const int offset = (access_id % counters_per_word) * bits;

    u32 count;
    // Many threads may update the same word concurrently, so retry the
    // read-modify-write until our CAS is the one that lands.
    while (true)
    {
        const u32 value = *word; // current packed word
        count = get_count(value, offset, bits) + 1; //for improve: change here for weighted distance
        *key_eligible = count >= static_cast<u32>(my_threshold);
        const u32 new_value = pack_count(value, offset, bits, count);
        if (atomicCAS(word, value, new_value) == value)
            break;
    }
    return count; // the count that was successfully written
}
447 
// Records one more item with value `count` in my_passCount. Then, while the
// bucket at the current threshold holds at least my_topk items, advances
// *my_threshold by one using atomicCAS. If several threads race, only one
// CAS per step succeeds and the others re-read the threshold.
//
// Does nothing if count is already below *my_threshold.
// NOTE(review): my_passCount is indexed by count and by threshold values;
// it is assumed to be large enough for both -- confirm with callers.
__device__ __forceinline__ void
updateThreshold(u32* my_passCount, u32* my_threshold,
        u32 my_topk, u32 count)
{
    if (count < *my_threshold)
        return; // threshold already moved past this count

    atomicAdd(my_passCount + count, 1u);

    for (;;)
    {
        const u32 current = *my_threshold;
        if (my_passCount[current] < my_topk)
            break;
        // Bucket is full: try to bump the threshold. The CAS fails harmlessly
        // if another thread bumped it first; the next iteration re-reads it.
        atomicCAS(my_threshold, current, current + 1);
    }
}
474 
475 } // namespace matching
476 } // namespace genie
477 
478 #endif
__device__ __forceinline__ void access_kernel_AT(u32 id, T_HASHTABLE *htable, int hash_table_size, genie::query::Query::dim &q, u32 count, bool *key_found, u32 *my_threshold, bool *pass_threshold)
const uint32_t KEY_TYPE_INIT_AGE
This is the top-level namespace of the project.
__forceinline__ __device__ u32 hash(T_KEY key, T_AGE age, int hash_table_size)
__device__ __forceinline__ void hash_kernel_AT(u32 id, T_HASHTABLE *htable, int hash_table_size, genie::query::Query::dim &q, u32 count, u32 *my_threshold, u32 *my_noiih, bool *overflow, bool *pass_threshold)
__host__ __forceinline__ __device__ T_HASHTABLE pack_key_pos(T_KEY p)
__forceinline__ __device__ __host__ void print_binary(char *b, u32 data)
const uint32_t ATTACH_ID_TYPE_MASK
const uint32_t ATTACH_ID_TYPE_BITS
u32 T_KEY
Definition: match_common.h:22
Declaration of query class.
__host__ __forceinline__ __device__ u32 get_key_attach_id(T_HASHTABLE key)
__forceinline__ __device__ __host__ u32 pack_count(u32 data, int offset, int bits, u32 count)
__forceinline__ __host__ __device__ T_AGE get_key_age(T_HASHTABLE key)
__forceinline__ __host__ __device__ T_KEY get_key_pos(T_HASHTABLE key)
__forceinline__ __device__ __host__ u32 get_count(u32 data, int offset, int bits)
const uint32_t KEY_TYPE_MASK
__forceinline__ __device__ void access_kernel(u32 id, T_HASHTABLE *htable, int hash_table_size, genie::query::Query::dim &q, bool *key_found)
__device__ __forceinline__ u32 bitmap_kernel_AT(u32 access_id, u32 *bitmap, int bits, int my_threshold, bool *key_eligible)
uint32_t u32
Definition: match_common.h:18
unsigned long long u64
A type definition for a 64-bit unsigned integer.
Definition: match_common.h:19
The second-step struct for processing queries.
Definition: query.h:59
__host__ __forceinline__ __device__ T_HASHTABLE pack_key_pos_and_attach_id_and_age(T_KEY p, u32 i, T_AGE a)
Basic utility functions to be used in matching kernels.
u64 T_HASHTABLE
Definition: match_common.h:21
const uint32_t KEY_TYPE_NULL_AGE
const uint32_t KEY_TYPE_BITS
__device__ __forceinline__ void updateThreshold(u32 *my_passCount, u32 *my_threshold, u32 my_topk, u32 count)
u32 T_AGE
Definition: match_common.h:23