Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions openleadr/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
from openleadr import errors
from datetime import datetime, timezone, timedelta
import os
from signxml.algorithms import SignatureMethod
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec, ed25519, ed448

from openleadr import utils
from .preflight import preflight_message

import logging
logger = logging.getLogger('openleadr')

SIGNER = XMLSigner(method=methods.detached,
c14n_algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315")
SIGNER.namespaces['oadr'] = "http://openadr.org/oadr-2.0b/2012/07"
VERIFIER = XMLVerifier()

XML_SCHEMA_LOCATION = os.path.join(os.path.dirname(__file__), 'schema', 'oadr_20b.xsd')
Expand Down Expand Up @@ -62,6 +62,45 @@ def parse_message(data):
return message_type, message_payload


def load_private_key(key_data, passphrase=None):
"""
Load the key based on key data. Supports .pem and .der keys.

Returns a private key object.
"""
passphrase_bytes = passphrase.encode() if passphrase else None
try:
key = serialization.load_pem_private_key(key_data, passphrase_bytes)
except ValueError:
try:
key = serialization.load_der_private_key(key_data, passphrase_bytes)
except ValueError:
logger.warning("Could not load key: unknown key file format.")
return key


def get_signature_algorithm_from_private_key(key_data, passphrase=None, default_algorithm="rsa-sha256"):
"""
Derive a signature algorithm based on the private key type. Returns a string that can be used to lookup
a signature algorithm by fragment. Algorithms are chosen based on NIST recommendations.

SignXML supports only RSA-, DSA- and EC-based signature methods. As XMLSigner uses RSA_SHA256 as default
signature algorithm, a fragment that results in this algorithm is returned for unsupported keys.
"""
key = load_private_key(key_data, passphrase)
if isinstance(key, rsa.RSAPrivateKey):
return "rsa-sha256"
elif isinstance(key, dsa.DSAPrivateKey):
return "dsa-sha256"
elif isinstance(key, ec.EllipticCurvePrivateKey):
return "ecdsa-sha256"
elif isinstance(key, ed25519.Ed25519PrivateKey):
logger.warning("ED25519 keys are not supported")
elif isinstance(key, ed448.Ed448PrivateKey):
logger.warning("ED448 keys are not supported")
return default_algorithm


def create_message(message_type, cert=None, key=None, passphrase=None, disable_signature=False, **message_payload):
"""
Create and optionally sign an OpenADR message. Returns an XML string.
Expand All @@ -72,6 +111,12 @@ def create_message(message_type, cert=None, key=None, passphrase=None, disable_s
envelope = TEMPLATES.get_template('oadrPayload.xml')
if cert and key and not disable_signature:
tree = etree.fromstring(signed_object)
SIGNER = XMLSigner(
method=methods.detached,
c14n_algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"
)
SIGNER.namespaces['oadr'] = "http://openadr.org/oadr-2.0b/2012/07"
SIGNER.sign_alg = SignatureMethod.from_fragment(get_signature_algorithm_from_private_key(key, passphrase))
signature_tree = SIGNER.sign(tree,
key=key,
cert=cert,
Expand All @@ -83,7 +128,8 @@ def create_message(message_type, cert=None, key=None, passphrase=None, disable_s
signature = None
msg = envelope.render(template=f'{message_type}',
signature=signature,
signed_object=signed_object)
signed_object=signed_object
)
logger.debug(f"Created message: {msg}")
return msg

Expand Down
32 changes: 32 additions & 0 deletions test/test_signature_algorithms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec, ed25519, ed448
from openleadr.messaging import get_signature_algorithm_from_private_key


test_keys = {
"rsa": rsa.generate_private_key(public_exponent=65537, key_size=2048),
"dsa": dsa.generate_private_key(key_size=2048),
"ec": ec.generate_private_key(ec.SECP256R1()),
"ed25519": ed25519.Ed25519PrivateKey.generate(),
"ed448": ed448.Ed448PrivateKey.generate()
}


@pytest.mark.parametrize("key_type, expected_alg", [
("rsa", "rsa-sha256"),
("dsa", "dsa-sha256"),
("ec", "ecdsa-sha256"),
("ed25519", "rsa-sha256"),
("ed448", "rsa-sha256"),
])
def test_key_type_sign_alg_match(key_type, expected_alg):
test_key = test_keys[key_type]
key_encoding = serialization.Encoding.PEM
key_format = serialization.PrivateFormat.PKCS8
key_encryption_alg = serialization.NoEncryption()
key_bytes = test_key.private_bytes(key_encoding, key_format, key_encryption_alg)

detected_alg = get_signature_algorithm_from_private_key(key_bytes)

assert detected_alg == expected_alg