Skip to content

Commit bc118cb

Browse files
committed
Karatsuba multiplication now works
1 parent f3bcc9b commit bc118cb

File tree

1 file changed

+24
-37
lines changed

1 file changed

+24
-37
lines changed

src/bigints.nim

Lines changed: 24 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ func initBigInt*(val: BigInt): BigInt =
6464
const
6565
zero = initBigInt(0)
6666
one = initBigInt(1)
67-
karatsubaTreshold = 5
67+
karatsubaTreshold = 10
6868

6969
func isZero(a: BigInt): bool {.inline.} =
7070
for i in countdown(a.limbs.high, 0):
@@ -418,35 +418,25 @@ func unsignedMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
418418
inc pos
419419
normalize(a)
420420

421-
func scalarMultiplication(a: var BigInt, b: uint32, c: BigInt) {.inline.} =
422-
# Based on unsignedMultiplication
421+
func scalarMultiplication(a: var BigInt, b: BigInt, c: uint32) {.inline.} =
422+
# a = b * c, where c is a single 32-bit limb (callers pass the multi-limb operand as b)
423423
let
424-
cl = c.limbs.len
425-
a.limbs.setLen(1 + cl)
424+
bl = b.limbs.len
425+
a.limbs.setLen(bl + 1)
426426
var tmp = 0'u64
427427

428-
tmp += uint64(b) * uint64(c.limbs[0])
429-
a.limbs[1] = uint32(tmp and uint32.high)
430-
tmp = tmp shr 32 # carry
431-
432-
a.limbs[1] = uint32(tmp)
433-
434-
for j in 1 ..< cl:
435-
tmp = 0'u64
436-
tmp += uint64(a.limbs[j]) + uint64(b) * uint64(c.limbs[j])
437-
a.limbs[j] = uint32(tmp and uint32.high)
428+
for i in 0 ..< bl:
429+
tmp += uint64(b.limbs[i]) * uint64(c)
430+
a.limbs[i] = uint32(tmp and uint32.high)
438431
tmp = tmp shr 32
439-
var pos = j + 1
440-
while tmp > 0'u64:
441-
tmp += uint64(a.limbs[pos])
442-
a.limbs[pos] = uint32(tmp and uint32.high)
443-
tmp = tmp shr 32
444-
inc pos
432+
433+
a.limbs[bl] = uint32(tmp)
445434
normalize(a)
446435

447436
# forward declaration for use in `multiplication`
448-
func unsignedKaratsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.}
437+
func karatsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.}
449438
func `shl`*(x: BigInt, y: Natural): BigInt
439+
func `shr`*(x: BigInt, y: Natural): BigInt
450440

451441
func multiplication(a: var BigInt, b, c: BigInt) =
452442
# a = b * c
@@ -459,28 +449,27 @@ func multiplication(a: var BigInt, b, c: BigInt) =
459449

460450
if cl > bl:
461451
if bl <= karatsubaTreshold:
462-
unsignedKaratsubaMultiplication(a, c, b)
452+
karatsubaMultiplication(a, c, b)
463453
else:
464454
unsignedMultiplication(a, c, b)
465455
else:
466456
if cl <= karatsubaTreshold:
467-
unsignedKaratsubaMultiplication(a, b, c)
457+
karatsubaMultiplication(a, b, c)
468458
else:
469459
unsignedMultiplication(a, b, c)
470460
a.isNegative = b.isNegative xor c.isNegative
471461

472-
func `shr`*(x: BigInt, y: Natural): BigInt
473-
func unsignedKaratsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
462+
func karatsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
474463
let
475464
bl = b.limbs.len
476465
cl = c.limbs.len
477466
let n = max(bl, cl)
478467
if bl == 1:
479468
# base case: b has a single limb, so multiply that limb by each limb of c
480-
scalarMultiplication(a, b.limbs[0], c)
469+
scalarMultiplication(a, c, b.limbs[0])
481470
return
482471
if cl == 1:
483-
scalarMultiplication(a, c.limbs[0], b)
472+
scalarMultiplication(a, b, c.limbs[0])
484473
return
485474
if bl < karatsubaTreshold:
486475
if cl <= bl:
@@ -507,21 +496,19 @@ func unsignedKaratsubaMultiplication(a: var BigInt, b, c: BigInt) {.inline.} =
507496
# limit carry handling in opposition to the additive version
508497
var
509498
lowProduct, highProduct, A3, A4, A5, middleTerm: BigInt = zero
510-
unsignedKaratsubaMultiplication(lowProduct, low_b, low_c)
511-
unsignedKaratsubaMultiplication(highProduct, high_b, high_c)
499+
karatsubaMultiplication(lowProduct, low_b, low_c)
500+
karatsubaMultiplication(highProduct, high_b, high_c)
512501
A3 = low_b - high_b # Additive variant of Karatsuba
513-
A4 = high_c - low_c # would add them
502+
A4 = low_c - high_c # would add them
514503
if A4.limbs.len >= A3.limbs.len:
515504
multiplication(A5, abs(A4), abs(A3))
516505
else:
517506
multiplication(A5, abs(A3), abs(A4))
518507
middleTerm = lowProduct + highProduct + A5
519-
a = lowProduct + (middleTerm shr k) + (highProduct shr (2*k))
520-
# We could affect directly some of the bits of the result with slicing
521-
# a.limbs[0 .. k - 1] = lowProduct.limbs
522-
# But the following instructions would not be correct due to sign handling
523-
# a.limbs[k .. 2*k-1] = middleTerm.limbs
524-
# a.limbs[2*k .. 3*k-1] = highProduct.limbs
508+
a.limbs[0 .. k - 1] = lowProduct.limbs
509+
# a += (middleTerm shr k) + (highProduct shr (2*k))
510+
a.limbs[k .. 2*k-1] = middleTerm.limbs
511+
a.limbs[2*k .. 3*k-1] = highProduct.limbs
525512

526513

527514
func `*`*(a, b: BigInt): BigInt =

0 commit comments

Comments
 (0)