diff --git a/src/algorithms/pss.rs b/src/algorithms/pss.rs index 78cdd27..99c6199 100644 --- a/src/algorithms/pss.rs +++ b/src/algorithms/pss.rs @@ -11,7 +11,7 @@ use alloc::vec::Vec; use digest::{Digest, DynDigest, FixedOutputReset}; -use subtle::{Choice, ConstantTimeEq}; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use super::mgf::{mgf1_xor, mgf1_xor_digest}; use crate::errors::{Error, Result}; @@ -170,7 +170,7 @@ fn emsa_pss_verify_pre<'a>( m_hash: &[u8], em: &'a mut [u8], em_bits: usize, - s_len: usize, + s_len: Option, h_len: usize, ) -> Result<(&'a mut [u8], &'a mut [u8])> { // 1. If the length of M is greater than the input limitation for the @@ -182,10 +182,12 @@ fn emsa_pss_verify_pre<'a>( return Err(Error::Verification); } - // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. let em_len = em.len(); //(em_bits + 7) / 8; - if em_len < h_len + s_len + 2 { - return Err(Error::Verification); + if let Some(s_len) = s_len { + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. + if em_len < h_len + s_len + 2 { + return Err(Error::Verification); + } } // 4. If the rightmost octet of EM does not have hexadecimal value @@ -227,10 +229,48 @@ fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> valid & rest[0].ct_eq(&0x01) } +/// Detect salt length by scanning DB for the 0x01 separator byte. +/// Returns (s_len, valid) where s_len is 0 on failure. +fn emsa_pss_get_salt_len(db: &[u8], em_len: usize, h_len: usize) -> (usize, Choice) { + let em_len = em_len as u32; + let h_len = h_len as u32; + let max_scan_len = em_len - h_len - 2; + + let mut separator_pos = 0u32; + let mut found_separator = Choice::from(0u8); + let mut padding_valid = Choice::from(1u8); + + // Single forward scan to find separator and validate padding + for i in 0..=max_scan_len { + let byte_val = db[i as usize]; + let is_zero = byte_val.ct_eq(&0x00); + let is_separator = byte_val.ct_eq(&0x01); + let is_invalid = !(is_zero | is_separator); + + // Update separator position if we found one and haven't found one before + let should_update_pos = is_separator & !found_separator; + separator_pos = u32::conditional_select(&separator_pos, &i, should_update_pos); + found_separator = + Choice::conditional_select(&found_separator, &Choice::from(1u8), should_update_pos); + + // Padding is invalid if we see a non-zero, non-separator byte before finding separator + let corrupts_padding = is_invalid & !found_separator; + padding_valid &= !corrupts_padding; + } + + let salt_len = max_scan_len.wrapping_sub(separator_pos); + let final_valid = found_separator & padding_valid; + + // Return 0 length on failure + let result_len = u32::conditional_select(&0u32, &salt_len, final_valid); + + (result_len as usize, final_valid) +} + pub(crate) fn emsa_pss_verify( m_hash: &[u8], em: &mut [u8], - s_len: usize, + s_len: Option, hash: &mut dyn DynDigest, key_bits: usize, ) -> Result<()> { @@ -252,7 +292,10 @@ pub(crate) fn emsa_pss_verify( // to zero. db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + let (s_len, salt_valid) = match s_len { + Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)), + None => emsa_pss_get_salt_len(db, em_len, h_len), + }; // 11. Let salt be the last s_len octets of DB. let salt = &db[db.len() - s_len..]; @@ -281,7 +324,7 @@ pub(crate) fn emsa_pss_verify( pub(crate) fn emsa_pss_verify_digest( m_hash: &[u8], em: &mut [u8], - s_len: usize, + s_len: Option, key_bits: usize, ) -> Result<()> where @@ -307,7 +350,10 @@ where // to zero. db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits); - let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len); + let (s_len, salt_valid) = match s_len { + Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)), + None => emsa_pss_get_salt_len(db, em_len, h_len), + }; // 11. Let salt be the last s_len octets of DB. let salt = &db[db.len() - s_len..]; diff --git a/src/pss.rs b/src/pss.rs index 3ef0e42..227deb0 100644 --- a/src/pss.rs +++ b/src/pss.rs @@ -51,7 +51,8 @@ pub struct Pss { pub digest: Box, /// Salt length. - pub salt_len: usize, + /// Required for signing, optional for verifying. + pub salt_len: Option, } impl Pss { @@ -66,7 +67,7 @@ impl Pss { Self { blinded: false, digest: Box::new(T::new()), - salt_len: len, + salt_len: Some(len), } } @@ -84,7 +85,7 @@ impl Pss { Self { blinded: true, digest: Box::new(T::new()), - salt_len: len, + salt_len: Some(len), } } } @@ -101,7 +102,7 @@ impl SignatureScheme for Pss { self.blinded, priv_key, hashed, - self.salt_len, + self.salt_len.expect("salt_len to be Some"), &mut *self.digest, ) } @@ -134,7 +135,7 @@ pub(crate) fn verify( sig: &BoxedUint, sig_len: usize, digest: &mut dyn DynDigest, - salt_len: usize, + salt_len: Option, ) -> Result<()> { if sig_len != pub_key.size() { return Err(Error::Verification); @@ -149,7 +150,7 @@ pub(crate) fn verify_digest( pub_key: &RsaPublicKey, hashed: &[u8], sig: &BoxedUint, - salt_len: usize, + salt_len: Option, ) -> Result<()> where D: Digest + FixedOutputReset, @@ -261,6 +262,7 @@ mod test { use crate::pss::{BlindedSigningKey, Pss, Signature, SigningKey, VerifyingKey}; use crate::{RsaPrivateKey, RsaPublicKey}; + use crate::traits::PublicKeyParts; use hex_literal::hex; use pkcs1::DecodeRsaPrivateKey; use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; @@ -622,4 +624,39 @@ tAboUGBxTDq3ZroNism3DaMIbKPyYrAqhKov1h5V .expect("failed to verify"); } } + + #[test] + // Tests the case where the salt length used for signing differs from the default length + // while the verifier uses auto-detection. + fn test_sign_and_verify_pss_differing_salt_len() { + let priv_key = get_private_key(); + + let tests = ["test\n"]; + let mut rng = ChaCha8Rng::from_seed([42; 32]); + + // signing keys using different salt lengths + let signing_keys = [ + // default salt length + SigningKey::::new(priv_key.clone()), + // maximum salt length + SigningKey::::new_with_salt_len( + priv_key.clone(), + priv_key.size() - Sha1::output_size() - 2, + ), + // unsalted + SigningKey::::new_with_salt_len(priv_key.clone(), 0), + ]; + + // verifying key uses default salt length strategy + let verifying_key = VerifyingKey::::new_with_auto_salt_len(priv_key.to_public_key()); + + for test in tests { + for signing_key in &signing_keys { + let sig = signing_key.sign_with_rng(&mut rng, test.as_bytes()); + verifying_key + .verify(test.as_bytes(), &sig) + .expect("verification to succeed"); + } + } + } } diff --git a/src/pss/blinded_signing_key.rs b/src/pss/blinded_signing_key.rs index 3c1b7a6..518f0f4 100644 --- a/src/pss/blinded_signing_key.rs +++ b/src/pss/blinded_signing_key.rs @@ -219,7 +219,7 @@ where fn verifying_key(&self) -> Self::VerifyingKey { VerifyingKey { inner: self.inner.to_public_key(), - salt_len: self.salt_len, + salt_len: Some(self.salt_len), phantom: Default::default(), } } diff --git a/src/pss/signing_key.rs b/src/pss/signing_key.rs index b67d86d..637c514 100644 --- a/src/pss/signing_key.rs +++ b/src/pss/signing_key.rs @@ -251,7 +251,7 @@ where fn verifying_key(&self) -> Self::VerifyingKey { VerifyingKey { inner: self.inner.to_public_key(), - salt_len: self.salt_len, + salt_len: Some(self.salt_len), phantom: Default::default(), } } diff --git a/src/pss/verifying_key.rs b/src/pss/verifying_key.rs index de96a1f..9303885 100644 --- a/src/pss/verifying_key.rs +++ b/src/pss/verifying_key.rs @@ -27,7 +27,7 @@ where D: Digest, { pub(super) inner: RsaPublicKey, - pub(super) salt_len: usize, + pub(super) salt_len: Option, pub(super) phantom: PhantomData, } @@ -45,13 +45,23 @@ where pub fn new_with_salt_len(key: RsaPublicKey, salt_len: usize) -> Self { Self { inner: key, - salt_len, + salt_len: Some(salt_len), + phantom: Default::default(), + } + } + + /// Create a new RSASSA-PSS verifying key. + /// Attempts to automatically detect the salt length. + pub fn new_with_auto_salt_len(key: RsaPublicKey) -> Self { + Self { + inner: key, + salt_len: None, phantom: Default::default(), } } /// Return specified salt length for this key - pub fn salt_len(&self) -> usize { + pub fn salt_len(&self) -> Option { self.salt_len } }