|
| 1 | +from typing import Iterator, List, Optional, Tuple |
| 2 | + |
| 3 | +from pyroaring import BitMap as BitMap |
| 4 | + |
| 5 | + |
class RoaringPositionBitmap:
    """A set of 64-bit positions stored as a list of 32-bit bitmaps.

    A position is split into a key (its upper 32 bits) and a 32-bit remainder
    (its lower 32 bits). ``self.bitmaps[key]`` holds the remainders of all set
    positions sharing that key, so the list index doubles as the key.
    """

    # Width in bytes of the bitmap-count field that starts the serialized form.
    BITMAP_COUNT_SIZE_BYTES: int = 8
    # Width in bytes of the key written in front of every serialized bitmap.
    BITMAP_KEY_SIZE_BYTES: int = 4
    # Largest accepted position: key 2**31 - 2 combined with low bits 0x80000000.
    MAX_POSITION = ((2**31 - 2) << 32) | ((-(2**31)) & 0xFFFFFFFF)

    def __init__(self, bitmaps: Optional[List["BitMap"]] = None):
        """Create a bitmap, optionally wrapping an existing list of per-key bitmaps.

        Args:
            bitmaps: Per-key bitmaps where the list index is the key. The list
                is stored as-is (not copied). Defaults to an empty list.
        """
        if bitmaps is not None:
            self.bitmaps = bitmaps
        else:
            self.bitmaps = []

    def set(self, pos: int) -> None:
        """Mark ``pos`` as set.

        Raises:
            ValueError: If ``pos`` is outside ``[0, MAX_POSITION]``.
        """
        self.validate_position(pos)
        key = self.key(pos)
        pos32 = self.pos32_bits(pos)
        self.allocate_bitmaps_if_needed(key + 1)
        self.bitmaps[key].add(pos32)

    def set_range(self, pos_start_inclusive: int, pos_end_exclusive: int) -> None:
        """Mark every position in ``[pos_start_inclusive, pos_end_exclusive)`` as set.

        An empty range (start >= end) is a no-op. Both ends are validated
        before anything is modified, so an invalid range leaves the bitmap
        untouched.

        Raises:
            ValueError: If the first or last position of the range is outside
                ``[0, MAX_POSITION]``.
        """
        if pos_start_inclusive >= pos_end_exclusive:
            return
        last_pos = pos_end_exclusive - 1
        self.validate_position(pos_start_inclusive)
        self.validate_position(last_pos)

        first_key = self.key(pos_start_inclusive)
        last_key = self.key(last_pos)
        self.allocate_bitmaps_if_needed(last_key + 1)
        # One bulk add_range per key instead of one add() per position.
        for key in range(first_key, last_key + 1):
            low = self.pos32_bits(pos_start_inclusive) if key == first_key else 0
            # The upper bound handed to add_range is exclusive.
            high = self.pos32_bits(last_pos) + 1 if key == last_key else 1 << 32
            self.bitmaps[key].add_range(low, high)

    def set_all(self, that: "RoaringPositionBitmap") -> None:
        """Merge every bitmap of ``that`` into this bitmap with an in-place ``|=``."""
        self.allocate_bitmaps_if_needed(len(that.bitmaps))
        for key, other in enumerate(that.bitmaps):
            self.bitmaps[key] |= other

    def contains(self, pos: int) -> bool:
        """Return whether ``pos`` is set.

        Raises:
            ValueError: If ``pos`` is outside ``[0, MAX_POSITION]``.
        """
        self.validate_position(pos)
        key = self.key(pos)
        pos32 = self.pos32_bits(pos)
        return key < len(self.bitmaps) and pos32 in self.bitmaps[key]

    def is_empty(self) -> bool:
        """Return True when no position is set."""
        return self.cardinality() == 0

    def cardinality(self) -> int:
        """Return the number of set positions (sum of the per-key bitmap lengths)."""
        return sum(len(bm) for bm in self.bitmaps)

    def run_length_encode(self) -> bool:
        """Call ``run_optimize()`` on every bitmap; return whether any call reported a change."""
        changed = False
        for bm in self.bitmaps:
            changed |= bm.run_optimize()
        return changed

    def allocated_bitmap_count(self) -> int:
        """Return the number of allocated per-key bitmaps, including empty ones."""
        return len(self.bitmaps)

    def allocate_bitmaps_if_needed(self, required_length: int) -> None:
        """Grow ``self.bitmaps`` with empty bitmaps until it has ``required_length`` entries."""
        missing = required_length - len(self.bitmaps)
        if missing > 0:
            # Rebind to a new list instead of extending in place, so a list
            # that was handed to __init__ is never grown behind the caller's back.
            self.bitmaps = self.bitmaps + [BitMap() for _ in range(missing)]

    def serialized_size_in_bytes(self) -> int:
        """Return the number of bytes ``serialize`` appends for the current content."""
        size = self.BITMAP_COUNT_SIZE_BYTES
        for bm in self.bitmaps:
            # Each bitmap is serialized here only to measure its length.
            size += self.BITMAP_KEY_SIZE_BYTES + len(bm.serialize())
        return size

    def serialize(self, result: bytearray) -> None:
        """Serialize the bitmap using a portable serialization format.

        Appends to ``result``: the bitmap count (little-endian), then for each
        bitmap its key (little-endian) followed by the bitmap's own
        serialized bytes.
        """
        result += len(self.bitmaps).to_bytes(self.BITMAP_COUNT_SIZE_BYTES, byteorder="little")
        for key, bitmap in enumerate(self.bitmaps):
            result += key.to_bytes(self.BITMAP_KEY_SIZE_BYTES, byteorder="little")
            result += bitmap.serialize()

    @classmethod
    def deserialize(cls, buffer: bytes) -> "RoaringPositionBitmap":
        """Deserializes a bitmap from a buffer, assuming the portable serialization format.

        Keys must appear in strictly ascending order; keys missing from the
        buffer are filled with empty bitmaps so list index and key stay aligned.

        Raises:
            ValueError: If the buffer is too short for a count or key field,
                or if the count or a key is invalid.
        """
        offset = 0
        bitmap_count, offset = cls.read_bitmap_count(buffer, offset)
        bitmaps: List[BitMap] = []
        last_key = -1

        for _ in range(bitmap_count):
            key, offset = cls.read_key(buffer, last_key, offset)
            # Fill gaps between the previous key and this one with empty bitmaps.
            while last_key < key - 1:
                bitmaps.append(BitMap())
                last_key += 1
            bitmap, offset = cls.read_bitmap(buffer, offset)
            bitmaps.append(bitmap)
            last_key = key
        return cls(bitmaps)

    @staticmethod
    def key(pos: int) -> int:
        """Return the upper 32 bits of ``pos``."""
        return (pos >> 32) & 0xFFFFFFFF

    @staticmethod
    def pos32_bits(pos: int) -> int:
        """Return the lower 32 bits of ``pos``."""
        return pos & 0xFFFFFFFF

    @staticmethod
    def to_position(key: int, pos32_bits: int) -> int:
        """Combine a key and a 32-bit remainder back into a full position."""
        return (key << 32) | (pos32_bits & 0xFFFFFFFF)

    @staticmethod
    def validate_position(pos: int) -> None:
        """Raise ``ValueError`` unless ``0 <= pos <= MAX_POSITION``."""
        if not (0 <= pos <= RoaringPositionBitmap.MAX_POSITION):
            raise ValueError(f"Bitmap supports positions that are >= 0 and <= {RoaringPositionBitmap.MAX_POSITION}: {pos}")

    @staticmethod
    def read_bitmap_count(buffer: bytes, offset: int) -> Tuple[int, int]:
        """Read the little-endian bitmap count at ``offset``.

        Returns:
            ``(bitmap_count, offset_after_the_count)``.

        Raises:
            ValueError: If fewer than ``BITMAP_COUNT_SIZE_BYTES`` bytes remain,
                or the count is not in ``[0, 2**31 - 1]``.
        """
        end = offset + RoaringPositionBitmap.BITMAP_COUNT_SIZE_BYTES
        # A short slice would otherwise be decoded silently as a smaller number.
        if end > len(buffer):
            raise ValueError(f"Buffer is too short to read the bitmap count at offset {offset}")
        bitmap_count = int.from_bytes(buffer[offset:end], byteorder="little")
        if not (0 <= bitmap_count <= 2**31 - 1):
            raise ValueError(f"Invalid bitmap count: {bitmap_count}")
        return bitmap_count, end

    @staticmethod
    def read_key(buffer: bytes, last_key: int, offset: int) -> Tuple[int, int]:
        """Read the little-endian 4-byte key at ``offset``.

        Args:
            buffer: Serialized data.
            last_key: Key read immediately before this one (``-1`` for none).
            offset: Position of the key inside ``buffer``.

        Returns:
            ``(key, offset_after_the_key)``.

        Raises:
            ValueError: If fewer than ``BITMAP_KEY_SIZE_BYTES`` bytes remain,
                the key is negative, larger than ``2**31 - 2``, or not greater
                than ``last_key``.
        """
        end = offset + RoaringPositionBitmap.BITMAP_KEY_SIZE_BYTES
        if end > len(buffer):
            raise ValueError(f"Buffer is too short to read a bitmap key at offset {offset}")
        # Decode as a signed 32-bit value so a key with the high bit set is
        # reported as negative by the check below.
        key = int.from_bytes(buffer[offset:end], byteorder="little", signed=True)
        if key < 0:
            raise ValueError(f"Invalid unsigned key: {key}")
        if key > (2**31 - 2):
            raise ValueError(f"Key is too large: {key}")
        if key <= last_key:
            raise ValueError("Keys must be sorted in ascending order")
        return key, end

    @staticmethod
    def read_bitmap(buffer: bytes, offset: int) -> Tuple["BitMap", int]:
        """Deserialize one bitmap starting at ``offset``.

        Returns:
            ``(bitmap, offset_after_the_bitmap)``.
        """
        bitmap = BitMap.deserialize(buffer[offset:])
        # The consumed length is measured by re-serializing the parsed bitmap.
        # NOTE(review): this assumes re-serialization reproduces the input
        # length byte-for-byte -- confirm for buffers written by other writers.
        return bitmap, offset + len(bitmap.serialize())

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over all set positions in the bitmap."""
        for key, bitmap in enumerate(self.bitmaps):
            for pos32 in bitmap:
                yield self.to_position(key, pos32)
0 commit comments