Skip to content

Commit

Permalink
Merge branch 'mainline' into c-value-store
Browse files Browse the repository at this point in the history
  • Loading branch information
hallogameboy authored Oct 3, 2023
2 parents b123b04 + cb6b067 commit 015d5e4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
7 changes: 7 additions & 0 deletions pecos/utils/mmap_hashmap_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ def get(self, keys, default_val):
ii) int2int: 1D numpy array of int64
2) The return is a reused buffer, use or copy the data once you get it. It is not guaranteed to last.
"""

if len(keys) > self.max_batch_size:
self.max_batch_size = max(len(keys), 2 * self.max_batch_size)
self.key_prealloc = self.mmap_r.get_keyalloc(self.max_batch_size)
self.vals = np.zeros(self.max_batch_size, dtype=np.uint64)
LOGGER.info(f"Increased the max batch size to {self.max_batch_size}")

self.mmap_r.batch_get(
len(keys),
self.key_prealloc.get_key_prealloc(keys),
Expand Down
12 changes: 10 additions & 2 deletions test/pecos/utils/test_mmap_hashmap_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@ def test_str2int_mmap_hashmap(tmpdir):
) # Non-exist key
vs = list(kv_dict.values()) + [10] * (max_batch_size - len(kv_dict))
assert r_map_batch_getter.get(ks, 10).tolist() == vs
# Cannot test max_batch_size < number of keys; it would result in a segmentation fault
# Use number of keys = 3 * max_batch_size
ks = list(kv_dict.keys()) + ["ccccc".encode("utf-8")] * (
3 * max_batch_size - len(kv_dict)
) # Non-exist key
vs = list(kv_dict.values()) + [10] * (3 * max_batch_size - len(kv_dict))
assert r_map_batch_getter.get(ks, 10).tolist() == vs


def test_int2int_mmap_hashmap(tmpdir):
Expand Down Expand Up @@ -107,4 +112,7 @@ def test_int2int_mmap_hashmap(tmpdir):
ks = list(kv_dict.keys()) + [1000] * (max_batch_size - len(kv_dict)) # Non-exist key
vs = list(kv_dict.values()) + [10] * (max_batch_size - len(kv_dict))
assert r_map_batch_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs
# Cannot test max_batch_size < number of keys; it would result in a segmentation fault
# Use number of keys = 3 * max_batch_size
ks = list(kv_dict.keys()) + [1000] * (3 * max_batch_size - len(kv_dict)) # Non-exist key
vs = list(kv_dict.values()) + [10] * (3 * max_batch_size - len(kv_dict))
assert r_map_batch_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs

0 comments on commit 015d5e4

Please sign in to comment.