diff --git a/openleadr/messaging.py b/openleadr/messaging.py index 5663fe1..c8d4c59 100644 --- a/openleadr/messaging.py +++ b/openleadr/messaging.py @@ -23,6 +23,9 @@ from openleadr import errors from datetime import datetime, timezone, timedelta import os +from signxml.algorithms import SignatureMethod +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec, ed25519, ed448 from openleadr import utils from .preflight import preflight_message @@ -30,9 +33,6 @@ import logging logger = logging.getLogger('openleadr') -SIGNER = XMLSigner(method=methods.detached, - c14n_algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315") -SIGNER.namespaces['oadr'] = "http://openadr.org/oadr-2.0b/2012/07" VERIFIER = XMLVerifier() XML_SCHEMA_LOCATION = os.path.join(os.path.dirname(__file__), 'schema', 'oadr_20b.xsd') @@ -62,6 +62,45 @@ def parse_message(data): return message_type, message_payload +def load_private_key(key_data, passphrase=None): + """ + Load the key based on key data. Supports .pem and .der keys. + + Returns a private key object. + """ + passphrase_bytes = passphrase.encode() if passphrase else None + try: + key = serialization.load_pem_private_key(key_data, passphrase_bytes) + except ValueError: + try: + key = serialization.load_der_private_key(key_data, passphrase_bytes) + except ValueError: + logger.warning("Could not load key: unknown key file format.") + return key + + +def get_signature_algorithm_from_private_key(key_data, passphrase=None, default_algorithm="rsa-sha256"): + """ + Derive a signature algorithm based on the private key type. Returns a string that can be used to lookup + a signature algorithm by fragment. Algorithms are chosen based on NIST recommendations. + + SignXML supports only RSA-, DSA- and EC-based signature methods. As XMLSigner uses RSA_SHA256 as default + signature algorithm, a fragment that results in this algorithm is returned for unsupported keys. + """ + key = load_private_key(key_data, passphrase) + if isinstance(key, rsa.RSAPrivateKey): + return "rsa-sha256" + elif isinstance(key, dsa.DSAPrivateKey): + return "dsa-sha256" + elif isinstance(key, ec.EllipticCurvePrivateKey): + return "ecdsa-sha256" + elif isinstance(key, ed25519.Ed25519PrivateKey): + logger.warning("ED25519 keys are not supported") + elif isinstance(key, ed448.Ed448PrivateKey): + logger.warning("ED448 keys are not supported") + return default_algorithm + + def create_message(message_type, cert=None, key=None, passphrase=None, disable_signature=False, **message_payload): """ Create and optionally sign an OpenADR message. Returns an XML string. @@ -72,6 +111,12 @@ def create_message(message_type, cert=None, key=None, passphrase=None, disable_s envelope = TEMPLATES.get_template('oadrPayload.xml') if cert and key and not disable_signature: tree = etree.fromstring(signed_object) + SIGNER = XMLSigner( + method=methods.detached, + c14n_algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315" + ) + SIGNER.namespaces['oadr'] = "http://openadr.org/oadr-2.0b/2012/07" + SIGNER.sign_alg = SignatureMethod.from_fragment(get_signature_algorithm_from_private_key(key, passphrase)) signature_tree = SIGNER.sign(tree, key=key, cert=cert, @@ -83,7 +128,8 @@ def create_message(message_type, cert=None, key=None, passphrase=None, disable_s signature = None msg = envelope.render(template=f'{message_type}', signature=signature, - signed_object=signed_object) + signed_object=signed_object + ) logger.debug(f"Created message: {msg}") return msg diff --git a/test/test_signature_algorithms.py b/test/test_signature_algorithms.py new file mode 100644 index 0000000..dea53d7 --- /dev/null +++ b/test/test_signature_algorithms.py @@ -0,0 +1,32 @@ +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec, ed25519, ed448 +from openleadr.messaging import get_signature_algorithm_from_private_key + + +test_keys = { + "rsa": rsa.generate_private_key(public_exponent=65537, key_size=2048), + "dsa": dsa.generate_private_key(key_size=2048), + "ec": ec.generate_private_key(ec.SECP256R1()), + "ed25519": ed25519.Ed25519PrivateKey.generate(), + "ed448": ed448.Ed448PrivateKey.generate() +} + + +@pytest.mark.parametrize("key_type, expected_alg", [ + ("rsa", "rsa-sha256"), + ("dsa", "dsa-sha256"), + ("ec", "ecdsa-sha256"), + ("ed25519", "rsa-sha256"), + ("ed448", "rsa-sha256"), +]) +def test_key_type_sign_alg_match(key_type, expected_alg): + test_key = test_keys[key_type] + key_encoding = serialization.Encoding.PEM + key_format = serialization.PrivateFormat.PKCS8 + key_encryption_alg = serialization.NoEncryption() + key_bytes = test_key.private_bytes(key_encoding, key_format, key_encryption_alg) + + detected_alg = get_signature_algorithm_from_private_key(key_bytes) + + assert detected_alg == expected_alg