Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ml-dsa/benches/ml_dsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let m: B32 = rand(&mut rng);
let ctx: B32 = rand(&mut rng);

let kp = MlDsa65::key_gen_internal(&xi);
let kp = MlDsa65::from_seed(&xi);
let sk = kp.signing_key();
let vk = kp.verifying_key();
let sig = sk.sign_deterministic(&m, &ctx).unwrap();
Expand All @@ -27,7 +27,7 @@ fn criterion_benchmark(c: &mut Criterion) {
// Key generation
c.bench_function("keygen", |b| {
b.iter(|| {
let kp = MlDsa65::key_gen_internal(&xi);
let kp = MlDsa65::from_seed(&xi);
let _sk_bytes = kp.signing_key().encode();
let _vk_bytes = kp.verifying_key().encode();
})
Expand All @@ -53,7 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) {
// Round trip
c.bench_function("round_trip", |b| {
b.iter(|| {
let kp = MlDsa65::key_gen_internal(&xi);
let kp = MlDsa65::from_seed(&xi);
let sig = kp.signing_key().sign_deterministic(&m, &ctx).unwrap();
let _ver = kp.verifying_key().verify_with_context(&m, &ctx, &sig);
})
Expand Down
55 changes: 43 additions & 12 deletions ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ where

let seed = Array::try_from(private_key_info.private_key.as_bytes())
.map_err(|_| pkcs8::Error::KeyMalformed)?;
Ok(P::key_gen_internal(&seed))
Ok(P::from_seed(&seed))
}
}

Expand Down Expand Up @@ -352,6 +352,16 @@ impl<P: MlDsaParams> SigningKey<P> {
}
}

/// Deterministically generate a signing key from the specified seed.
///
/// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204, but only returns a
/// signing key.
#[must_use]
pub fn from_seed(seed: &B32) -> Self {
let kp = P::from_seed(seed);
kp.signing_key
}

/// This method reflects the ML-DSA.Sign_internal algorithm from FIPS 204. It does not
/// include the domain separator that distinguishes between the normal and pre-hashed cases,
/// and it does not separate the context string from the rest of the message.
Expand Down Expand Up @@ -913,8 +923,9 @@ pub trait KeyGen: MlDsaParams {
fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> Self::KeyPair;

/// Deterministically generate a signing key pair from the specified seed
// TODO(RLB): Only expose this based on a feature.
fn key_gen_internal(xi: &B32) -> Self::KeyPair;
///
/// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204.
fn from_seed(xi: &B32) -> Self::KeyPair;
}

impl<P> KeyGen for P
Expand All @@ -929,12 +940,14 @@ where
fn key_gen<R: CryptoRng + ?Sized>(rng: &mut R) -> KeyPair<P> {
let mut xi = B32::default();
rng.fill_bytes(&mut xi);
Self::key_gen_internal(&xi)
Self::from_seed(&xi)
}

/// Deterministically generate a signing key pair from the specified seed
///
/// This method reflects the ML-DSA.KeyGen_internal algorithm from FIPS 204.
// Algorithm 6 ML-DSA.KeyGen_internal
fn key_gen_internal(xi: &B32) -> KeyPair<P>
fn from_seed(xi: &B32) -> KeyPair<P>
where
P: MlDsaParams,
{
Expand Down Expand Up @@ -1001,7 +1014,7 @@ mod test {
where
P: MlDsaParams + PartialEq,
{
let kp = P::key_gen_internal(&Array::default());
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;

Expand Down Expand Up @@ -1032,7 +1045,7 @@ mod test {
where
P: MlDsaParams + PartialEq,
{
let kp = P::key_gen_internal(&Array::default());
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;
let vk_derived = sk.verifying_key();
Expand All @@ -1051,7 +1064,7 @@ mod test {
where
P: MlDsaParams,
{
let kp = P::key_gen_internal(&Array::default());
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;

Expand Down Expand Up @@ -1084,7 +1097,7 @@ mod test {
let seed_data: &mut [u8] = seed.as_mut();
rng.fill(seed_data);

let kp = P::key_gen_internal(&seed);
let kp = P::from_seed(&seed);
let sk = kp.signing_key;
let vk = kp.verifying_key;

Expand Down Expand Up @@ -1113,7 +1126,7 @@ mod test {
where
P: MlDsaParams,
{
let kp = P::key_gen_internal(&Array::default());
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;

Expand All @@ -1135,7 +1148,7 @@ mod test {
where
P: MlDsaParams,
{
let kp = P::key_gen_internal(&Array::default());
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;

Expand All @@ -1157,7 +1170,7 @@ mod test {
where
P: MlDsaParams,
{
let kp = P::key_gen_internal(&Array::default());
let kp = P::from_seed(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;

Expand All @@ -1172,4 +1185,22 @@ mod test {
sign_internal_verify_mu::<MlDsa65>();
sign_internal_verify_mu::<MlDsa87>();
}

#[test]
fn from_seed_implementations_match() {
fn assert_from_seed_equality<P>()
where
P: MlDsaParams,
{
let seed = Array([0u8; 32]);
let kp1 = P::from_seed(&seed);
let sk1 = SigningKey::<P>::from_seed(&seed);
let vk1 = sk1.verifying_key();
assert_eq!(kp1.signing_key, sk1);
assert_eq!(kp1.verifying_key, vk1);
}
assert_from_seed_equality::<MlDsa44>();
assert_from_seed_equality::<MlDsa65>();
assert_from_seed_equality::<MlDsa87>();
}
}
2 changes: 1 addition & 1 deletion ml-dsa/tests/key-gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn verify<P: MlDsaParams>(tc: &acvp::TestCase) {
let vk_bytes = EncodedVerifyingKey::<P>::try_from(tc.pk.as_slice()).unwrap();
let sk_bytes = EncodedSigningKey::<P>::try_from(tc.sk.as_slice()).unwrap();

let kp = P::key_gen_internal(&seed);
let kp = P::from_seed(&seed);
let sk = kp.signing_key().clone();
let vk = kp.verifying_key().clone();

Expand Down
6 changes: 3 additions & 3 deletions ml-dsa/tests/proptests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ const MSG: &[u8] = b"Hello world";
// Keypairs
prop_compose! {
fn mldsa44_keypair()(seed_bytes in any::<[u8; 32]>()) -> KeyPair<MlDsa44> {
MlDsa44::key_gen_internal(seed_bytes.as_array_ref())
MlDsa44::from_seed(seed_bytes.as_array_ref())
}
}
prop_compose! {
fn mldsa65_keypair()(seed_bytes in any::<[u8; 32]>()) -> KeyPair<MlDsa65> {
MlDsa65::key_gen_internal(seed_bytes.as_array_ref())
MlDsa65::from_seed(seed_bytes.as_array_ref())
}
}
prop_compose! {
fn mldsa87_keypair()(seed_bytes in any::<[u8; 32]>()) -> KeyPair<MlDsa87> {
MlDsa87::key_gen_internal(seed_bytes.as_array_ref())
MlDsa87::from_seed(seed_bytes.as_array_ref())
}
}

Expand Down