diff --git a/src/bigints.nim b/src/bigints.nim index 9ef5abb..41a504d 100644 --- a/src/bigints.nim +++ b/src/bigints.nim @@ -1198,3 +1198,49 @@ func powmod*(base, exponent, modulus: BigInt): BigInt = result = (result * basePow) mod modulus basePow = (basePow * basePow) mod modulus exponent = exponent shr 1 + +func getLimbAndBitPosFromBitPos(bit: Natural): tuple[bit, limb: Natural] {.inline.} = + (bit: Natural(bit and 31), limb: Natural(bit shr 5)) + +func setBit*(a: var BigInt; bit: Natural) = + ## Mutates `a`, with the bit at position `bit` set to 1. + runnableExamples: + var v = 0b0000_0011.initBigInt + v.setBit(5) + doAssert v == 0b0010_0011.initBigInt + + let (b, l) = getLimbAndBitPosFromBitPos(bit) + + if l >= a.limbs.len: + a.limbs.setLen(l + 1) + + a.limbs[l] = a.limbs[l] or (1'u32 shl b) + +func clearBit*(a: var BigInt; bit: Natural) = + ## Mutates `v`, with the bit at position `bit` set to 0. + runnableExamples: + var v = 0b0000_0011.initBigInt + v.clearBit(1) + doAssert v == 0b0000_0001.initBigInt + + let (b, l) = getLimbAndBitPosFromBitPos(bit) + + if l >= a.limbs.len: + return + + a.limbs[l] = a.limbs[l] and not (1'u32 shl b) + normalize(a) + +func testBit*(a: BigInt; bit: Natural): bool = + ## Returns true if the bit in `a` at positions `bit` is set to 1. + runnableExamples: + let v = 0b0000_1111.initBigInt + doAssert v.testBit(0) + doAssert not v.testBit(7) + + let (b, l) = getLimbAndBitPosFromBitPos(bit) + + if l >= a.limbs.len: + false + else: + (a.limbs[l] and (1'u32 shl b)) != 0 diff --git a/tests/tbigints.nim b/tests/tbigints.nim index 3b3a59d..95f6073 100644 --- a/tests/tbigints.nim +++ b/tests/tbigints.nim @@ -786,6 +786,88 @@ proc main() = doAssert pred(a, 3) == initBigInt(4) doAssert succ(a, 3) == initBigInt(10) + block: # setBit/clearBit/clearBit + var a = initBigInt(0) + for i in 0..256: + doAssert not a.testBit(i) + for i in 0..256: + a.setBit(i) + for j in 0..i: + doAssert a.testBit(j) + for j in (i + 1)..256: + doAssert not a.testBit(j) + for i in countDown(256, 0): + doAssert a.testBit(i) + for i in countDown(256, 0): + a.clearBit(i) + for j in 0..(i - 1): + doAssert a.testBit(j) + for j in i..256: + doAssert not a.testBit(j) + for i in countDown(256, 0): + doAssert not a.testBit(i) + for i in 1..256: + a.setBit(i) + doAssert a.testBit(i) + doAssert not a.testBit(i - 1) + doAssert not a.testBit(i + 1) + a.clearBit(i) + doAssert not a.testBit(i) + doAssert not a.testBit(i - 1) + doAssert not a.testBit(i + 1) + doAssert a == initBigInt(0) + + a = initBigInt(0) + a.setBit(0) + doAssert a == initBigInt(1) + doAssert a.testBit(0) + a.setBit(1) + doAssert a == initBigInt(3) + doAssert a.testBit(0) and a.testBit(1) + a.setBit(31) + doAssert a == initBigInt((1 shl 31) + 3) + doAssert a.testBit(0) and a.testBit(1) and a.testBit(31) + a.setBit(30) + doAssert a == initBigInt((1 shl 31) + (1 shl 30) + 3) + doAssert a.testBit(0) and a.testBit(1) and a.testBit(30) and a.testBit(31) + a.clearBit(31) + doAssert a == initBigInt((1 shl 30) + 3) + doAssert a.testBit(0) and a.testBit(1) and a.testBit(30) and (not a.testBit(31)) + a.clearBit(30) + doAssert a == initBigInt(3) + doAssert a.testBit(0) and a.testBit(1) and (not a.testBit(30)) and (not a.testBit(31)) + a.clearBit(1) + doAssert a == initBigInt(1) + doAssert a.testBit(0) and (not a.testBit(1)) and (not a.testBit(30)) and (not a.testBit(31)) + a.clearBit(0) + doAssert a == initBigInt(0) + + a = initBigInt(0) + a.setBit(63) + doAssert a == initBigInt(1'u64 shl 63'u64) + doAssert (not a.testBit(1)) and (not a.testBit(31)) and (not a.testBit(32)) and a.testBit(63) + a.setBit(1) + doAssert a == initBigInt((1'u64 shl 63'u64) + 2) + doAssert a.testBit(1) and (not a.testBit(31)) and (not a.testBit(32)) and a.testBit(63) + a.setBit(32) + doAssert a == initBigInt((1'u64 shl 63'u64) + (1'u64 shl 32'u64) + 2) + doAssert a.testBit(1) and (not a.testBit(31)) and a.testBit(32) and a.testBit(63) + a.setBit(31) + doAssert a == initBigInt((1'u64 shl 63'u64) + (1'u64 shl 32'u64) + (1'u64 shl 31'u64) + 2) + doAssert a.testBit(1) and a.testBit(31) and a.testBit(32) and a.testBit(63) + + a.clearBit(63) + doAssert a == initBigInt((1'u64 shl 32'u64) + (1'u64 shl 31'u64) + 2) + doAssert a.testBit(1) and a.testBit(31) and a.testBit(32) and (not a.testBit(63)) + a.clearBit(1) + doAssert a == initBigInt((1'u64 shl 32'u64) + (1'u64 shl 31'u64)) + doAssert (not a.testBit(1)) and a.testBit(31) and a.testBit(32) and (not a.testBit(63)) + a.clearBit(32) + doAssert a == initBigInt((1'u64 shl 31'u64)) + doAssert (not a.testBit(1)) and a.testBit(31) and (not a.testBit(32)) and (not a.testBit(63)) + a.clearBit(31) + doAssert a == initBigInt(0) + doAssert (not a.testBit(1)) and (not a.testBit(31)) and (not a.testBit(32)) and (not a.testBit(63)) static: main() main()