diff --git a/pyroaring/abstract_bitmap.pxi b/pyroaring/abstract_bitmap.pxi index 1db09bc..2e7419a 100644 --- a/pyroaring/abstract_bitmap.pxi +++ b/pyroaring/abstract_bitmap.pxi @@ -116,6 +116,8 @@ cdef class AbstractBitMap: croaring.roaring_bitmap_free(self._c_bitmap) def _check_compatibility(self, AbstractBitMap other): + if other is None: + raise TypeError('Argument has incorrect type (expected pyroaring.AbstractBitMap, got None)') if self.copy_on_write != other.copy_on_write: raise ValueError('Cannot have interactions between bitmaps with and without copy_on_write.\n') @@ -128,21 +130,33 @@ cdef class AbstractBitMap: def __len__(self): return croaring.roaring_bitmap_get_cardinality(self._c_bitmap) - def __richcmp__(self, other, int op): + def __lt__(self, AbstractBitMap other): self._check_compatibility(other) - if op == 0: # < - return croaring.roaring_bitmap_is_strict_subset((self)._c_bitmap, (other)._c_bitmap) - elif op == 1: # <= - return croaring.roaring_bitmap_is_subset((self)._c_bitmap, (other)._c_bitmap) - elif op == 2: # == - return croaring.roaring_bitmap_equals((self)._c_bitmap, (other)._c_bitmap) - elif op == 3: # != - return not (self == other) - elif op == 4: # > - return croaring.roaring_bitmap_is_strict_subset((other)._c_bitmap, (self)._c_bitmap) - else: # >= - assert op == 5 - return croaring.roaring_bitmap_is_subset((other)._c_bitmap, (self)._c_bitmap) + return croaring.roaring_bitmap_is_strict_subset((self)._c_bitmap, (other)._c_bitmap) + + def __le__(self, AbstractBitMap other): + self._check_compatibility(other) + return croaring.roaring_bitmap_is_subset((self)._c_bitmap, (other)._c_bitmap) + + def __eq__(self, object other): + if not isinstance(other, AbstractBitMap): + return NotImplemented + self._check_compatibility(other) + return croaring.roaring_bitmap_equals((self)._c_bitmap, (other)._c_bitmap) + + def __ne__(self, object other): + if not isinstance(other, AbstractBitMap): + return NotImplemented + self._check_compatibility(other) + return not croaring.roaring_bitmap_equals((self)._c_bitmap, (other)._c_bitmap) + + def __gt__(self, AbstractBitMap other): + self._check_compatibility(other) + return croaring.roaring_bitmap_is_strict_subset((other)._c_bitmap, (self)._c_bitmap) + + def __ge__(self, AbstractBitMap other): + self._check_compatibility(other) + return croaring.roaring_bitmap_is_subset((other)._c_bitmap, (self)._c_bitmap) def contains_range(self, uint64_t range_start, uint64_t range_end): """ diff --git a/test.py b/test.py index 8724c21..0131eb9 100755 --- a/test.py +++ b/test.py @@ -528,7 +528,7 @@ def test_comparison( values2: HypCollection, cow: bool, ) -> None: - for op in [operator.le, operator.ge, operator.lt, operator.gt]: + for op in [operator.le, operator.ge, operator.lt, operator.gt, operator.eq, operator.ne]: self.set1 = set(values1) self.set2 = set(values2) self.bitmap1 = cls1(values1, copy_on_write=cow) @@ -542,6 +542,15 @@ def test_comparison( self.assertEqual(op(self.set1, self.set1 | self.set2), op(self.set1, self.set1 | self.set2)) + @given(bitmap_cls, hyp_collection, st.booleans()) + def test_comparison_other_objects(self, cls: type[EitherBitMap], values: HypCollection, cow: bool) -> None: + for op in [operator.le, operator.ge, operator.lt, operator.gt]: + bm = cls(values, copy_on_write=cow) + with self.assertRaises(TypeError): + op(bm, 42) + with self.assertRaises(TypeError): + op(bm, None) + @given(bitmap_cls, bitmap_cls, hyp_collection, hyp_collection, st.booleans()) def test_intersect( self, @@ -555,6 +564,24 @@ def test_intersect( bm2 = cls2(values2, copy_on_write=cow) self.assertEqual(bm1.intersect(bm2), len(bm1 & bm2) > 0) + @given(bitmap_cls, hyp_collection, st.booleans()) + def test_eq_other_objects(self, cls: type[EitherBitMap], values: HypCollection, cow: bool) -> None: + bm = cls(values, copy_on_write=cow) + + self.assertFalse(bm == 42) + self.assertIs(cls.__eq__(bm, 42), NotImplemented) + self.assertFalse(bm == None) # noqa: E711 + self.assertIs(cls.__eq__(bm, None), NotImplemented) + + @given(bitmap_cls, hyp_collection, st.booleans()) + def test_ne_other_objects(self, cls: type[EitherBitMap], values: HypCollection, cow: bool) -> None: + bm = cls(values, copy_on_write=cow) + + self.assertTrue(bm != 42) + self.assertIs(cls.__ne__(bm, 42), NotImplemented) + self.assertTrue(bm != None) # noqa: E711 + self.assertIs(cls.__ne__(bm, None), NotImplemented) + class RangeTest(Util): @given(bitmap_cls, hyp_collection, st.booleans(), uint32, uint32)