From 4890bdb36441aabbd80578fc66441045a31100e1 Mon Sep 17 00:00:00 2001 From: Peter Law Date: Tue, 30 Jan 2024 20:43:56 +0000 Subject: [PATCH] Improve type handling in rich comparisons (#111) Previously an error was raised when doing equality comparisons between BitMaps and other instance types, which is unexpected in Python. Move over to implementing each of the rich comparison methods directly so that we can more match the usual pattern of allowed (and rejected) comparisons. Also explicitly check for `None` when validating that two BitMaps can be compared, so that the error message emitted in that case is of the more expected type (i.e: `TypeError` rather than `AttributeError`). --- pyroaring/abstract_bitmap.pxi | 42 +++++++++++++++++++++++------------ test.py | 29 +++++++++++++++++++++++- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/pyroaring/abstract_bitmap.pxi b/pyroaring/abstract_bitmap.pxi index 1db09bc..2e7419a 100644 --- a/pyroaring/abstract_bitmap.pxi +++ b/pyroaring/abstract_bitmap.pxi @@ -116,6 +116,8 @@ cdef class AbstractBitMap: croaring.roaring_bitmap_free(self._c_bitmap) def _check_compatibility(self, AbstractBitMap other): + if other is None: + raise TypeError('Argument has incorrect type (expected pyroaring.AbstractBitMap, got None)') if self.copy_on_write != other.copy_on_write: raise ValueError('Cannot have interactions between bitmaps with and without copy_on_write.\n') @@ -128,21 +130,33 @@ cdef class AbstractBitMap: def __len__(self): return croaring.roaring_bitmap_get_cardinality(self._c_bitmap) - def __richcmp__(self, other, int op): + def __lt__(self, AbstractBitMap other): self._check_compatibility(other) - if op == 0: # < - return croaring.roaring_bitmap_is_strict_subset((self)._c_bitmap, (other)._c_bitmap) - elif op == 1: # <= - return croaring.roaring_bitmap_is_subset((self)._c_bitmap, (other)._c_bitmap) - elif op == 2: # == - return croaring.roaring_bitmap_equals((self)._c_bitmap, (other)._c_bitmap) - elif op == 3: # != - return not (self == other) - elif op == 4: # > - return croaring.roaring_bitmap_is_strict_subset((other)._c_bitmap, (self)._c_bitmap) - else: # >= - assert op == 5 - return croaring.roaring_bitmap_is_subset((other)._c_bitmap, (self)._c_bitmap) + return croaring.roaring_bitmap_is_strict_subset((self)._c_bitmap, (other)._c_bitmap) + + def __le__(self, AbstractBitMap other): + self._check_compatibility(other) + return croaring.roaring_bitmap_is_subset((self)._c_bitmap, (other)._c_bitmap) + + def __eq__(self, object other): + if not isinstance(other, AbstractBitMap): + return NotImplemented + self._check_compatibility(other) + return croaring.roaring_bitmap_equals((self)._c_bitmap, (other)._c_bitmap) + + def __ne__(self, object other): + if not isinstance(other, AbstractBitMap): + return NotImplemented + self._check_compatibility(other) + return not croaring.roaring_bitmap_equals((self)._c_bitmap, (other)._c_bitmap) + + def __gt__(self, AbstractBitMap other): + self._check_compatibility(other) + return croaring.roaring_bitmap_is_strict_subset((other)._c_bitmap, (self)._c_bitmap) + + def __ge__(self, AbstractBitMap other): + self._check_compatibility(other) + return croaring.roaring_bitmap_is_subset((other)._c_bitmap, (self)._c_bitmap) def contains_range(self, uint64_t range_start, uint64_t range_end): """ diff --git a/test.py b/test.py index 8724c21..0131eb9 100755 --- a/test.py +++ b/test.py @@ -528,7 +528,7 @@ def test_comparison( values2: HypCollection, cow: bool, ) -> None: - for op in [operator.le, operator.ge, operator.lt, operator.gt]: + for op in [operator.le, operator.ge, operator.lt, operator.gt, operator.eq, operator.ne]: self.set1 = set(values1) self.set2 = set(values2) self.bitmap1 = cls1(values1, copy_on_write=cow) @@ -542,6 +542,15 @@ def test_comparison( self.assertEqual(op(self.set1, self.set1 | self.set2), op(self.set1, self.set1 | self.set2)) + @given(bitmap_cls, hyp_collection, st.booleans()) + def test_comparison_other_objects(self, cls: type[EitherBitMap], values: HypCollection, cow: bool) -> None: + for op in [operator.le, operator.ge, operator.lt, operator.gt]: + bm = cls(values, copy_on_write=cow) + with self.assertRaises(TypeError): + op(bm, 42) + with self.assertRaises(TypeError): + op(bm, None) + @given(bitmap_cls, bitmap_cls, hyp_collection, hyp_collection, st.booleans()) def test_intersect( self, @@ -555,6 +564,24 @@ def test_intersect( bm2 = cls2(values2, copy_on_write=cow) self.assertEqual(bm1.intersect(bm2), len(bm1 & bm2) > 0) + @given(bitmap_cls, hyp_collection, st.booleans()) + def test_eq_other_objects(self, cls: type[EitherBitMap], values: HypCollection, cow: bool) -> None: + bm = cls(values, copy_on_write=cow) + + self.assertFalse(bm == 42) + self.assertIs(cls.__eq__(bm, 42), NotImplemented) + self.assertFalse(bm == None) # noqa: E711 + self.assertIs(cls.__eq__(bm, None), NotImplemented) + + @given(bitmap_cls, hyp_collection, st.booleans()) + def test_ne_other_objects(self, cls: type[EitherBitMap], values: HypCollection, cow: bool) -> None: + bm = cls(values, copy_on_write=cow) + + self.assertTrue(bm != 42) + self.assertIs(cls.__ne__(bm, 42), NotImplemented) + self.assertTrue(bm != None) # noqa: E711 + self.assertIs(cls.__ne__(bm, None), NotImplemented) + class RangeTest(Util): @given(bitmap_cls, hyp_collection, st.booleans(), uint32, uint32)