Skip to content

Commit e37d06e

Browse files
committed
Encoding and decoding byte byte representations of Edwards points
1 parent 9557d97 commit e37d06e

File tree

2 files changed

+205
-2
lines changed

2 files changed

+205
-2
lines changed

src/ecdsa/ellipticcurve.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
from six import python_2_unicode_compatible
5151
from . import numbertheory
52-
from ._compat import normalise_bytes
52+
from ._compat import normalise_bytes, int_to_bytes, bit_length, bytes_to_int
5353
from .errors import MalformedPointError
5454
from .util import orderlen, string_to_number, number_to_string
5555

@@ -278,6 +278,41 @@ def _from_hybrid(cls, data, raw_encoding_length, validate_encoding):
278278

279279
return x, y
280280

281+
@classmethod
282+
def _from_edwards(cls, curve, data):
283+
"""Decode a point on an Edwards curve."""
284+
data = bytearray(data)
285+
p = curve.p()
286+
# add 1 for the sign bit and then round up
287+
exp_len = (bit_length(p) + 1 + 7) // 8
288+
if len(data) != exp_len:
289+
raise MalformedPointError("Point length doesn't match the curve.")
290+
x_0 = (data[-1] & 0x80) >> 7
291+
292+
data[-1] &= 0x80 - 1
293+
294+
y = bytes_to_int(data, "little")
295+
if GMPY:
296+
y = mpz(y)
297+
298+
x2 = (
299+
(y * y - 1)
300+
* numbertheory.inverse_mod(curve.d() * y * y - curve.a(), p)
301+
% p
302+
)
303+
304+
try:
305+
x = numbertheory.square_root_mod_prime(x2, p)
306+
except numbertheory.SquareRootError as e:
307+
raise MalformedPointError(
308+
"Encoding does not correspond to a point on curve", e
309+
)
310+
311+
if x % 2 != x_0:
312+
x = -x % p
313+
314+
return x, y
315+
281316
@classmethod
282317
def from_bytes(
283318
cls, curve, data, validate_encoding=True, valid_encodings=None
@@ -325,6 +360,10 @@ def from_bytes(
325360
"supported."
326361
)
327362
data = normalise_bytes(data)
363+
364+
if isinstance(curve, CurveEdTw):
365+
return cls._from_edwards(curve, data)
366+
328367
key_len = len(data)
329368
raw_encoding_length = 2 * orderlen(curve.p())
330369
if key_len == raw_encoding_length and "raw" in valid_encodings:
@@ -381,6 +420,18 @@ def _hybrid_encode(self):
381420
return b"\x07" + raw_enc
382421
return b"\x06" + raw_enc
383422

423+
def _edwards_encode(self):
424+
"""Encode the point according to RFC8032 encoding."""
425+
self.scale()
426+
x, y, p = self.x(), self.y(), self.curve().p()
427+
428+
# add 1 for the sign bit and then round up
429+
enc_len = (bit_length(p) + 1 + 7) // 8
430+
y_str = int_to_bytes(y, enc_len, "little")
431+
if x % 2:
432+
y_str[-1] |= 0x80
433+
return y_str
434+
384435
def to_bytes(self, encoding="raw"):
385436
"""
386437
Convert the point to a byte string.
@@ -389,11 +440,17 @@ def to_bytes(self, encoding="raw"):
389440
by `encoding="raw"`. It can also output points in :term:`uncompressed`,
390441
:term:`compressed`, and :term:`hybrid` formats.
391442
443+
For points on Edwards curves `encoding` is ignored and only the
444+
encoding defined in RFC 8032 is supported.
445+
392446
:return: :term:`raw encoding` of a public on the curve
393447
:rtype: bytes
394448
"""
395449
assert encoding in ("raw", "uncompressed", "compressed", "hybrid")
396-
if encoding == "raw":
450+
curve = self.curve()
451+
if isinstance(curve, CurveEdTw):
452+
return self._edwards_encode()
453+
elif encoding == "raw":
397454
return self._raw_encode()
398455
elif encoding == "uncompressed":
399456
return b"\x04" + self._raw_encode()
@@ -1219,6 +1276,48 @@ def __init__(self, curve, x, y, z, t, order=None):
12191276
self.__coords = (x, y, z, t)
12201277
self.__order = order
12211278

1279+
@classmethod
1280+
def from_bytes(
1281+
cls,
1282+
curve,
1283+
data,
1284+
validate_encoding=None,
1285+
valid_encodings=None,
1286+
order=None,
1287+
generator=False,
1288+
):
1289+
"""
1290+
Initialise the object from byte encoding of a point.
1291+
1292+
`validate_encoding` and `valid_encodings` are provided for
1293+
compatibility with Weierstrass curves, they are ignored for Edwards
1294+
points.
1295+
1296+
:param data: single point encoding of the public key
1297+
:type data: :term:`bytes-like object`
1298+
:param curve: the curve on which the public key is expected to lay
1299+
:type curve: ecdsa.ellipticcurve.CurveEdTw
1300+
:param None validate_encoding: Ignored, encoding is always validated
1301+
:param None valid_encodings: Ignored, there is just one encoding
1302+
supported
1303+
:param int order: the point order, must be non zero when using
1304+
generator=True
1305+
:param bool generator: Ignored, may be used in the future
1306+
to precompute point multiplication table.
1307+
1308+
:raises MalformedPointError: if the public point does not lay on the
1309+
curve or the encoding is invalid
1310+
1311+
:return: Initialised point on an Edwards curve
1312+
:rtype: PointEdwards
1313+
"""
1314+
coord_x, coord_y = super(PointEdwards, cls).from_bytes(
1315+
curve, data, validate_encoding, valid_encodings
1316+
)
1317+
return PointEdwards(
1318+
curve, coord_x, coord_y, 1, coord_x * coord_y, order
1319+
)
1320+
12221321
def x(self):
12231322
"""Return affine x coordinate."""
12241323
X1, _, Z1, _ = self.__coords

src/ecdsa/test_eddsa.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import unittest2 as unittest
55
except ImportError:
66
import unittest
7+
from hypothesis import given, settings, example
8+
import hypothesis.strategies as st
79
from .ellipticcurve import PointEdwards, INFINITY, CurveEdTw
810
from .eddsa import (
911
generator_ed25519,
@@ -12,6 +14,7 @@
1214
curve_ed448,
1315
)
1416
from .ecdsa import generator_256, curve_256
17+
from .errors import MalformedPointError
1518

1619

1720
def test_ed25519_curve_compare():
@@ -405,3 +408,104 @@ def test_ed448_add_and_mul_equivalence():
405408

406409
assert g + g == g * 2
407410
assert g + g + g == g * 3
411+
412+
413+
def test_ed25519_encode():
414+
g = generator_ed25519
415+
g_bytes = g.to_bytes()
416+
assert len(g_bytes) == 32
417+
exp_bytes = (
418+
b"\x58\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66"
419+
b"\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66"
420+
)
421+
assert g_bytes == exp_bytes
422+
423+
424+
def test_ed25519_decode():
425+
exp_bytes = (
426+
b"\x58\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66"
427+
b"\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66"
428+
)
429+
a = PointEdwards.from_bytes(curve_ed25519, exp_bytes)
430+
431+
assert a == generator_ed25519
432+
433+
434+
class TestEdwardsMalformed(unittest.TestCase):
435+
def test_invalid_point(self):
436+
exp_bytes = (
437+
b"\x78\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66"
438+
b"\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66"
439+
)
440+
with self.assertRaises(MalformedPointError):
441+
PointEdwards.from_bytes(curve_ed25519, exp_bytes)
442+
443+
def test_invalid_length(self):
444+
exp_bytes = (
445+
b"\x58\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66"
446+
b"\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66"
447+
b"\x66"
448+
)
449+
with self.assertRaises(MalformedPointError) as e:
450+
PointEdwards.from_bytes(curve_ed25519, exp_bytes)
451+
452+
self.assertIn("length", str(e.exception))
453+
454+
def test_ed448_invalid(self):
455+
exp_bytes = b"\xff" * 57
456+
with self.assertRaises(MalformedPointError):
457+
PointEdwards.from_bytes(curve_ed448, exp_bytes)
458+
459+
460+
def test_ed448_encode():
461+
g = generator_ed448
462+
g_bytes = g.to_bytes()
463+
assert len(g_bytes) == 57
464+
exp_bytes = (
465+
b"\x14\xfa\x30\xf2\x5b\x79\x08\x98\xad\xc8\xd7\x4e\x2c\x13\xbd"
466+
b"\xfd\xc4\x39\x7c\xe6\x1c\xff\xd3\x3a\xd7\xc2\xa0\x05\x1e\x9c"
467+
b"\x78\x87\x40\x98\xa3\x6c\x73\x73\xea\x4b\x62\xc7\xc9\x56\x37"
468+
b"\x20\x76\x88\x24\xbc\xb6\x6e\x71\x46\x3f\x69\x00"
469+
)
470+
assert g_bytes == exp_bytes
471+
472+
473+
def test_ed448_decode():
474+
exp_bytes = (
475+
b"\x14\xfa\x30\xf2\x5b\x79\x08\x98\xad\xc8\xd7\x4e\x2c\x13\xbd"
476+
b"\xfd\xc4\x39\x7c\xe6\x1c\xff\xd3\x3a\xd7\xc2\xa0\x05\x1e\x9c"
477+
b"\x78\x87\x40\x98\xa3\x6c\x73\x73\xea\x4b\x62\xc7\xc9\x56\x37"
478+
b"\x20\x76\x88\x24\xbc\xb6\x6e\x71\x46\x3f\x69\x00"
479+
)
480+
481+
a = PointEdwards.from_bytes(curve_ed448, exp_bytes)
482+
483+
assert a == generator_ed448
484+
485+
486+
HYP_SETTINGS = dict()
487+
HYP_SETTINGS["max_examples"] = 10
488+
489+
490+
@settings(**HYP_SETTINGS)
491+
@example(1)
492+
@example(5) # smallest multiple that requires changing sign of x
493+
@given(st.integers(min_value=1, max_value=int(generator_ed25519.order() - 1)))
494+
def test_ed25519_encode_decode(multiple):
495+
a = generator_ed25519 * multiple
496+
497+
b = PointEdwards.from_bytes(curve_ed25519, a.to_bytes())
498+
499+
assert a == b
500+
501+
502+
@settings(**HYP_SETTINGS)
503+
@example(1)
504+
@example(2) # smallest multiple that requires changing the sign of x
505+
@given(st.integers(min_value=1, max_value=int(generator_ed448.order() - 1)))
506+
def test_ed448_encode_decode(multiple):
507+
a = generator_ed448 * multiple
508+
509+
b = PointEdwards.from_bytes(curve_ed448, a.to_bytes())
510+
511+
assert a == b

0 commit comments

Comments
 (0)