Skip to content

Commit

Permalink
Improve type handling in rich comparisons (#111)
Browse files Browse the repository at this point in the history
Previously an error was raised when doing equality comparisons
between BitMaps and other instance types, which is unexpected in
Python. Move over to implementing each of the rich comparison
methods directly so that we can more match the usual pattern of
allowed (and rejected) comparisons.

Also explicitly check for `None` when validating that two BitMaps
can be compared, so that the error message emitted in that case
is of the more expected type (i.e: `TypeError` rather than
`AttributeError`).
  • Loading branch information
PeterJCLaw authored Jan 30, 2024
1 parent 678552b commit 4890bdb
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 15 deletions.
42 changes: 28 additions & 14 deletions pyroaring/abstract_bitmap.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ cdef class AbstractBitMap:
croaring.roaring_bitmap_free(self._c_bitmap)

def _check_compatibility(self, AbstractBitMap other):
if other is None:
raise TypeError('Argument has incorrect type (expected pyroaring.AbstractBitMap, got None)')
if self.copy_on_write != other.copy_on_write:
raise ValueError('Cannot have interactions between bitmaps with and without copy_on_write.\n')

Expand All @@ -128,21 +130,33 @@ cdef class AbstractBitMap:
def __len__(self):
return croaring.roaring_bitmap_get_cardinality(self._c_bitmap)

def __richcmp__(self, other, int op):
def __lt__(self, AbstractBitMap other):
self._check_compatibility(other)
if op == 0: # <
return croaring.roaring_bitmap_is_strict_subset((<AbstractBitMap?>self)._c_bitmap, (<AbstractBitMap?>other)._c_bitmap)
elif op == 1: # <=
return croaring.roaring_bitmap_is_subset((<AbstractBitMap?>self)._c_bitmap, (<AbstractBitMap?>other)._c_bitmap)
elif op == 2: # ==
return croaring.roaring_bitmap_equals((<AbstractBitMap?>self)._c_bitmap, (<AbstractBitMap?>other)._c_bitmap)
elif op == 3: # !=
return not (self == other)
elif op == 4: # >
return croaring.roaring_bitmap_is_strict_subset((<AbstractBitMap?>other)._c_bitmap, (<AbstractBitMap?>self)._c_bitmap)
else: # >=
assert op == 5
return croaring.roaring_bitmap_is_subset((<AbstractBitMap?>other)._c_bitmap, (<AbstractBitMap?>self)._c_bitmap)
return croaring.roaring_bitmap_is_strict_subset((<AbstractBitMap?>self)._c_bitmap, (<AbstractBitMap?>other)._c_bitmap)

def __le__(self, AbstractBitMap other):
self._check_compatibility(other)
return croaring.roaring_bitmap_is_subset((<AbstractBitMap?>self)._c_bitmap, (<AbstractBitMap?>other)._c_bitmap)

def __eq__(self, object other):
if not isinstance(other, AbstractBitMap):
return NotImplemented
self._check_compatibility(other)
return croaring.roaring_bitmap_equals((<AbstractBitMap?>self)._c_bitmap, (<AbstractBitMap?>other)._c_bitmap)

def __ne__(self, object other):
if not isinstance(other, AbstractBitMap):
return NotImplemented
self._check_compatibility(other)
return not croaring.roaring_bitmap_equals((<AbstractBitMap?>self)._c_bitmap, (<AbstractBitMap?>other)._c_bitmap)

def __gt__(self, AbstractBitMap other):
self._check_compatibility(other)
return croaring.roaring_bitmap_is_strict_subset((<AbstractBitMap?>other)._c_bitmap, (<AbstractBitMap?>self)._c_bitmap)

def __ge__(self, AbstractBitMap other):
self._check_compatibility(other)
return croaring.roaring_bitmap_is_subset((<AbstractBitMap?>other)._c_bitmap, (<AbstractBitMap?>self)._c_bitmap)

def contains_range(self, uint64_t range_start, uint64_t range_end):
"""
Expand Down
29 changes: 28 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def test_comparison(
values2: HypCollection,
cow: bool,
) -> None:
for op in [operator.le, operator.ge, operator.lt, operator.gt]:
for op in [operator.le, operator.ge, operator.lt, operator.gt, operator.eq, operator.ne]:
self.set1 = set(values1)
self.set2 = set(values2)
self.bitmap1 = cls1(values1, copy_on_write=cow)
Expand All @@ -542,6 +542,15 @@ def test_comparison(
self.assertEqual(op(self.set1, self.set1 | self.set2),
op(self.set1, self.set1 | self.set2))

@given(bitmap_cls, hyp_collection, st.booleans())
def test_comparison_other_objects(self, cls: type[EitherBitMap], values: HypCollection, cow: bool) -> None:
for op in [operator.le, operator.ge, operator.lt, operator.gt]:
bm = cls(values, copy_on_write=cow)
with self.assertRaises(TypeError):
op(bm, 42)
with self.assertRaises(TypeError):
op(bm, None)

@given(bitmap_cls, bitmap_cls, hyp_collection, hyp_collection, st.booleans())
def test_intersect(
self,
Expand All @@ -555,6 +564,24 @@ def test_intersect(
bm2 = cls2(values2, copy_on_write=cow)
self.assertEqual(bm1.intersect(bm2), len(bm1 & bm2) > 0)

@given(bitmap_cls, hyp_collection, st.booleans())
def test_eq_other_objects(self, cls: type[EitherBitMap], values: HypCollection, cow: bool) -> None:
bm = cls(values, copy_on_write=cow)

self.assertFalse(bm == 42)
self.assertIs(cls.__eq__(bm, 42), NotImplemented)
self.assertFalse(bm == None) # noqa: E711
self.assertIs(cls.__eq__(bm, None), NotImplemented)

@given(bitmap_cls, hyp_collection, st.booleans())
def test_ne_other_objects(self, cls: type[EitherBitMap], values: HypCollection, cow: bool) -> None:
bm = cls(values, copy_on_write=cow)

self.assertTrue(bm != 42)
self.assertIs(cls.__ne__(bm, 42), NotImplemented)
self.assertTrue(bm != None) # noqa: E711
self.assertIs(cls.__ne__(bm, None), NotImplemented)


class RangeTest(Util):
@given(bitmap_cls, hyp_collection, st.booleans(), uint32, uint32)
Expand Down

0 comments on commit 4890bdb

Please sign in to comment.