Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 151 additions & 24 deletions crypto_condor/primitives/AES.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import importlib
import json
import logging
import math
import subprocess
import sys
import zipfile
Expand Down Expand Up @@ -197,7 +198,24 @@ def _get_aes_lib() -> tuple[cffi.FFI | None, _cffi_backend.Lib | None]:
void AES_CFB_decrypt_buffer(struct AES_ctx *ctx,
uint8_t *buffer, size_t length,
size_t segment_size);
void AES_CTR_xcrypt_buffer(struct AES_ctx *ctx, uint8_t *buffer, size_t length);
void AES_CTR_xcrypt_buffer(struct AES_ctx *ctx, uint8_t *buffer,
size_t length);
void AES_KW_encrypt_buffer(const struct AES_ctx *ctx, uint8_t *buffer,
size_t length);
void AES_KW_encrypt_inv_buffer(const struct AES_ctx *ctx, uint8_t *buffer,
size_t length);
void AES_KWP_encrypt_buffer(const struct AES_ctx *ctx, uint8_t *buffer,
size_t length);
void AES_KWP_encrypt_inv_buffer(const struct AES_ctx *ctx, uint8_t *buffer,
size_t length);
int AES_KW_decrypt_buffer(const struct AES_ctx *ctx, uint8_t *buffer,
size_t length);
int AES_KW_decrypt_inv_buffer(const struct AES_ctx *ctx, uint8_t *buffer,
size_t length);
int AES_KWP_decrypt_buffer(const struct AES_ctx *ctx, uint8_t *buffer,
size_t length);
int AES_KWP_decrypt_inv_buffer(const struct AES_ctx *ctx, uint8_t *buffer,
size_t length);
"""
)

Expand Down Expand Up @@ -241,7 +259,7 @@ class Encrypt(Protocol):
- CCM or GCM
"""

# ECB
# ECB / KW / KW_INV / KWP / KWP_INV
@overload
def __call__(self, key: bytes, plaintext: bytes) -> bytes: ...

Expand Down Expand Up @@ -301,7 +319,7 @@ class Decrypt(Protocol):
- CCM or GCM
"""

# ECB
# ECB / KW / KW_INV / KWP / KWP_INV
@overload
def __call__(self, key: bytes, ciphertext: bytes) -> bytes: ...

Expand Down Expand Up @@ -461,9 +479,11 @@ class ParsingError(Exception):
# ----------------------------- AES functions -----------------------------------------


# ECB
# ECB / KW / KW_INV / KWP / KWP_INV
@overload
def _encrypt(mode: Literal[Mode.ECB], key: bytes, plaintext: bytes) -> bytes: ...
def _encrypt(mode: Literal[Mode.ECB, Mode.KW, Mode.KW_INV, Mode.KWP, Mode.KWP_INV],
key: bytes,
plaintext: bytes) -> bytes: ...


# CBC / CTR
Expand Down Expand Up @@ -512,7 +532,7 @@ def _encrypt(
mac_len: (CCM/GCM) The length of the authentication tag.

Returns:
(ECB/CBC/CTR/CFB) The resulting ciphertext.
(ECB/KW/KW_INV/KWP/KPW_INV/CBC/CTR/CFB) The resulting ciphertext.

(CCM/GCM) A (ciphertext, tag) tuple.

Expand Down Expand Up @@ -551,13 +571,50 @@ def _encrypt(

ctx = ffi.new("struct AES_ctx *")
ctx_key = ffi.new(f"uint8_t[{len(key)}]", key)
buf = ffi.new(f"uint8_t[{len(plaintext)}]", plaintext)

# Deal with ECB first, so we can initialize a single IV if it's not ECB.
# Deal with ECB/KW/KW_INV/KWP/KWP_INV first, so we can initialize a single IV
# if it's not ECB or KW/KW_INV/KWP/KWP_INV.
if mode == Mode.ECB:
buf = ffi.new(f"uint8_t[{len(plaintext)}]", plaintext)
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
lib.AES_ECB_encrypt_buffer(ctx, buf, len(plaintext)) # type: ignore[attr-defined]
return bytes(buf)
elif mode == Mode.KW:
buf = ffi.new(f"uint8_t[{len(plaintext) + 8}]", plaintext)
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
lib.AES_KW_encrypt_buffer(ctx, buf, len(plaintext) + 8) # type: ignore[attr-defined]
return bytes(buf)
elif mode == Mode.KW_INV:
buf = ffi.new(f"uint8_t[{len(plaintext) + 8}]", plaintext) # type: ignore[attr-defined]
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
lib.AES_KW_encrypt_inv_buffer(ctx, buf, len(plaintext) + 8) # type: ignore[attr-defined]
return bytes(buf)
elif mode == Mode.KWP:
# Compute the buffer size needed (plaintext + padding + IV space)
if len(plaintext) == 0:
# 0-byte plaintext needs 8 bytes padding + 8 bytes IV = 16 total
buffer_size = 16
else:
padding_len = 8 * math.ceil(len(plaintext) / 8) - len(plaintext)
buffer_size = len(plaintext) + padding_len + 8
buf = ffi.new(f"uint8_t[{buffer_size}]", plaintext)
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
lib.AES_KWP_encrypt_buffer(ctx, buf, len(plaintext)) # type: ignore[attr-defined]
return bytes(buf)
elif mode == Mode.KWP_INV:
# Compute the buffer size needed (plaintext + padding + IV space)
if len(plaintext) == 0:
# 0-byte plaintext needs 8 bytes padding + 8 bytes IV = 16 total
buffer_size = 16
else:
padding_len = 8 * math.ceil(len(plaintext) / 8) - len(plaintext)
buffer_size = len(plaintext) + padding_len + 8
buf = ffi.new(f"uint8_t[{buffer_size}]", plaintext)
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
lib.AES_KWP_encrypt_inv_buffer(ctx, buf, len(plaintext)) # type: ignore[attr-defined]
return bytes(buf)
else:
buf = ffi.new(f"uint8_t[{len(plaintext)}]", plaintext)

if iv is None:
raise ValueError(f"{str(mode)} mode requires an IV")
Expand All @@ -581,8 +638,11 @@ def _encrypt(
return bytes(buf)


# ECB / KW / KW_INV / KWP / KWP_INV
@overload
def _decrypt(mode: Literal[Mode.ECB], key: bytes, ciphertext: bytes) -> bytes: ...
def _decrypt(mode: Literal[Mode.ECB, Mode.KW, Mode.KW_INV, Mode.KWP, Mode.KWP_INV],
key: bytes,
ciphertext: bytes) -> bytes: ...


# CBC / CTR
Expand Down Expand Up @@ -619,7 +679,7 @@ def _decrypt(
aad: bytes | None = None,
mac: bytes | None = None,
mac_len: int = 0,
) -> bytes | PlaintextAndBool:
) -> bytes | PlaintextAndBool | None:
"""Decrypts with AES.

Args:
Expand Down Expand Up @@ -686,14 +746,49 @@ def _decrypt(

ctx = ffi.new("struct AES_ctx *")
ctx_key = ffi.new(f"uint8_t[{len(key)}]", key)
buf = ffi.new(f"uint8_t[{len(ciphertext)}]", ciphertext)

# Deal with ECB first, so we can initialize a single IV if it's not ECB.
# Deal with ECB and KW/KW_INV/KWP/KWP_INV first,
# So we can initialize a single IV if it's not those modes.
if mode == Mode.ECB:
buf = ffi.new(f"uint8_t[{len(ciphertext)}]", ciphertext)
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
lib.AES_ECB_decrypt_buffer(ctx, buf, len(ciphertext)) # type: ignore[attr-defined]
return bytes(buf)

elif mode == Mode.KW:
buf = ffi.new(f"uint8_t[{len(ciphertext)}]", ciphertext)
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
ret_val = lib.AES_KW_decrypt_buffer(ctx, buf, len(ciphertext)) # type: ignore[attr-defined]
if ret_val == 0:
return None
return bytes(buf)[:len(ciphertext) - 8]
elif mode == Mode.KW_INV:
buf = ffi.new(f"uint8_t[{len(ciphertext)}]", ciphertext)
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
ret_val = lib.AES_KW_decrypt_inv_buffer(ctx, buf, len(ciphertext)) # type: ignore[attr-defined]
if ret_val == 0:
return None
return bytes(buf)[:len(ciphertext) - 8]
elif mode == Mode.KWP:
# Ensure minimum buffer size for KWP operations
buffer_size = max(len(ciphertext), 16)
buf = ffi.new(f"uint8_t[{buffer_size}]", ciphertext)
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
ret_val = lib.AES_KWP_decrypt_buffer(ctx, buf, len(ciphertext)) # type: ignore[attr-defined]
if ret_val == 0:
return None
return bytes(buf)[:ret_val]
elif mode == Mode.KWP_INV:
# Ensure minimum buffer size for KWP operations
buffer_size = max(len(ciphertext), 16)
buf = ffi.new(f"uint8_t[{buffer_size}]", ciphertext)
lib.AES_init_ctx(ctx, ctx_key, len(key)) # type: ignore[attr-defined]
ret_val = lib.AES_KWP_decrypt_inv_buffer(ctx, buf, len(ciphertext)) # type: ignore[attr-defined]
if ret_val == 0:
return None
return bytes(buf)[:ret_val]
else:
buf = ffi.new(f"uint8_t[{len(ciphertext)}]", ciphertext)

if iv is None:
raise ValueError(f"{str(mode)} mode requires an IV")
ctx_iv = ffi.new("uint8_t[16]", iv)
Expand Down Expand Up @@ -902,7 +997,7 @@ def _load_vectors(
def _try_one_enc(enc: Encrypt, mode: Mode, test: AesTest) -> tuple[bytes, bytes | None]:
ret_tag: bytes | None = None
match mode:
case Mode.ECB:
case Mode.ECB | Mode.KW | Mode.KW_INV | Mode.KWP | Mode.KWP_INV:
ret_ct = enc(test.key, test.pt)
case Mode.CBC | Mode.CTR | Mode.CFB | Mode.CFB8 | Mode.CFB128:
ret_ct = enc(test.key, test.pt, iv=test.iv)
Expand All @@ -924,7 +1019,7 @@ def _try_one_dec(
ret_valid_tag: bool | None

match mode:
case Mode.ECB:
case Mode.ECB | Mode.KW | Mode.KW_INV | Mode.KWP | Mode.KWP_INV:
ret_pt = dec(test.key, test.ct)
ret_valid_tag = None
case Mode.CBC | Mode.CTR | Mode.CFB | Mode.CFB8 | Mode.CFB128 | Mode.CBC_PKCS7:
Expand Down Expand Up @@ -1258,13 +1353,16 @@ def test_decrypt(
def _test_output_enc(line: str, mode: Mode):
match line.rstrip().split("/"):
case [_k, _p, _c]:
if mode != Mode.ECB:
raise ParsingError("Got 3 values but the mode is not ECB")
if mode != Mode.ECB and mode != Mode.KW and\
mode != Mode.KW_INV and mode != Mode.KWP and mode != Mode.KWP_INV:
raise ParsingError("Got 3 values but the mode is not ECB or KW")
key, pt, ct = map(bytes.fromhex, (_k, _p, _c))
ref_ct = _encrypt(mode, key, pt)
return EncData(key, pt, ct, None, None, None, ref_ct, None)
case [_k, _p, _c, _i]:
if mode == Mode.ECB or mode == Mode.CCM or mode == Mode.GCM:
if mode == Mode.ECB or mode == Mode.KW or mode == Mode.KW_INV or\
mode == Mode.KWP or mode == Mode.KWP_INV or mode == Mode.CCM or\
mode == Mode.GCM:
raise ParsingError(f"Got 4 values but the mode is {str(mode)}")
key, pt, ct, iv = map(bytes.fromhex, (_k, _p, _c, _i))
ref_ct = _encrypt(mode, key, pt, iv=iv)
Expand Down Expand Up @@ -1383,13 +1481,17 @@ def test_output_encrypt(filename: str, mode: Mode) -> ResultsDict:
def _test_output_dec(line: str, mode: Mode):
match line.rstrip().split("/"):
case [_k, _c, _p]:
if mode != Mode.ECB:
if mode != Mode.ECB and mode != Mode.KW and\
mode != Mode.KW_INV and mode != Mode.KWP and\
mode != Mode.KWP_INV:
raise ParsingError("Got 3 values but the mode is not ECB")
key, ct, pt = map(bytes.fromhex, (_k, _c, _p))
ref_pt = _decrypt(mode, key, ct)
return DecData(key, ct, pt, None, None, None, ref_pt, None)
case [_k, _c, _p, _i]:
if mode == Mode.ECB or mode == Mode.CCM or mode == Mode.GCM:
if mode == Mode.ECB or mode == Mode.KW or mode == Mode.KW_INV or\
mode == Mode.KWP or mode == Mode.KWP_INV or mode == Mode.CCM or\
mode == Mode.GCM:
raise ParsingError(f"Got 4 values but the mode is {str(mode)}")
key, ct, pt, iv = map(bytes.fromhex, (_k, _c, _p, _i))
ref_pt = _decrypt(mode, key, ct, iv=iv)
Expand Down Expand Up @@ -1724,6 +1826,10 @@ def _enc(key: bytes, plaintext: bytes, iv: bytes = b"") -> bytes:
pad_len = 16 - (len(plaintext) % 16)
ct_len = len(plaintext) + pad_len
c_ct = ffi.new(f"uint8_t[{ct_len}]")
elif mode in {Mode.KW, Mode.KW_INV, Mode.KWP, Mode.KWP_INV}:
# Key Wrap modes always add 8 bytes to the output
ct_len = len(plaintext) + 8
c_ct = ffi.new(f"uint8_t[{ct_len}]")
else:
# ct_len = ((len(plaintext) + 15) // 16) * 16
ct_len = len(plaintext)
Expand Down Expand Up @@ -1771,10 +1877,31 @@ def _dec(key: bytes, ciphertext: bytes, iv: bytes = b"") -> bytes:
c_key = ffi.new("uint8_t[]", key)
c_ct = ffi.new("uint8_t[]", ciphertext)
c_iv = ffi.new("uint8_t[]", iv)
c_pt = ffi.new(f"uint8_t[{len(ciphertext)}]")
rc = dec(
c_pt, len(ciphertext), c_ct, len(ciphertext), c_key, len(key), c_iv, len(iv)
)
if mode in {Mode.KWP, Mode.KWP_INV}:
# For KWP modes, the plaintext length is variable
# And determined by the C function
# Allocate a buffer as large as the ciphertext
# The function returns the actual length.
c_pt = ffi.new(f"uint8_t[{len(ciphertext)}]")
rc = dec(
c_pt, len(ciphertext), c_ct, len(ciphertext),
c_key, len(key), c_iv, len(iv)
)
elif mode in {Mode.KW, Mode.KW_INV}:
# For KW modes (without padding)
# The output plaintext is always ciphertext length minus 8 bytes.
pt_len = len(ciphertext) - 8
c_pt = ffi.new(f"uint8_t[{pt_len}]")
rc = dec(
c_pt, pt_len, c_ct, len(ciphertext),
c_key, len(key), c_iv, len(iv)
)
else:
c_pt = ffi.new(f"uint8_t[{len(ciphertext)}]")
rc = dec(
c_pt, len(ciphertext), c_ct, len(ciphertext),
c_key, len(key), c_iv, len(iv)
)
if rc >= 0:
return bytes(c_pt)[:rc]
else:
Expand Down
2 changes: 2 additions & 0 deletions crypto_condor/primitives/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ all: libs

# Installs libs through the installation function of their module.
libs:
# Force recompilation of AES C code
zip -j _aes/AES.zip _aes/aes.c _aes/aes.h
python AES.py
python TestU01.py
python MLDSA.py
Expand Down
Loading