Skip to content

PSS: Improve interoperability with optional auto salt length detection during verification #546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions src/algorithms/pss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

use alloc::vec::Vec;
use digest::{Digest, DynDigest, FixedOutputReset};
use subtle::{Choice, ConstantTimeEq};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};

use super::mgf::{mgf1_xor, mgf1_xor_digest};
use crate::errors::{Error, Result};
Expand Down Expand Up @@ -170,7 +170,7 @@ fn emsa_pss_verify_pre<'a>(
m_hash: &[u8],
em: &'a mut [u8],
em_bits: usize,
s_len: usize,
s_len: Option<usize>,
h_len: usize,
) -> Result<(&'a mut [u8], &'a mut [u8])> {
// 1. If the length of M is greater than the input limitation for the
Expand All @@ -182,10 +182,12 @@ fn emsa_pss_verify_pre<'a>(
return Err(Error::Verification);
}

// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
let em_len = em.len(); //(em_bits + 7) / 8;
if em_len < h_len + s_len + 2 {
return Err(Error::Verification);
if let Some(s_len) = s_len {
// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
if em_len < h_len + s_len + 2 {
return Err(Error::Verification);
}
}

// 4. If the rightmost octet of EM does not have hexadecimal value
Expand Down Expand Up @@ -227,10 +229,48 @@ fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) ->
valid & rest[0].ct_eq(&0x01)
}

/// Detect salt length by scanning DB for the 0x01 separator byte.
/// Returns (s_len, valid) where s_len is 0 on failure.
fn emsa_pss_get_salt_len(db: &[u8], em_len: usize, h_len: usize) -> (usize, Choice) {
let em_len = em_len as u32;
let h_len = h_len as u32;
let max_scan_len = em_len - h_len - 2;

let mut separator_pos = 0u32;
let mut found_separator = Choice::from(0u8);
let mut padding_valid = Choice::from(1u8);

// Single forward scan to find separator and validate padding
for i in 0..=max_scan_len {
let byte_val = db[i as usize];
let is_zero = byte_val.ct_eq(&0x00);
let is_separator = byte_val.ct_eq(&0x01);
let is_invalid = !(is_zero | is_separator);

// Update separator position if we found one and haven't found one before
let should_update_pos = is_separator & !found_separator;
separator_pos = u32::conditional_select(&separator_pos, &i, should_update_pos);
found_separator =
Choice::conditional_select(&found_separator, &Choice::from(1u8), should_update_pos);

// Padding is invalid if we see a non-zero, non-separator byte before finding separator
let corrupts_padding = is_invalid & !found_separator;
padding_valid &= !corrupts_padding;
}

let salt_len = max_scan_len.wrapping_sub(separator_pos);
let final_valid = found_separator & padding_valid;

// Return 0 length on failure
let result_len = u32::conditional_select(&0u32, &salt_len, final_valid);

(result_len as usize, final_valid)
}

pub(crate) fn emsa_pss_verify(
m_hash: &[u8],
em: &mut [u8],
s_len: usize,
s_len: Option<usize>,
hash: &mut dyn DynDigest,
key_bits: usize,
) -> Result<()> {
Expand All @@ -252,7 +292,10 @@ pub(crate) fn emsa_pss_verify(
// to zero.
db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);

let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len);
let (s_len, salt_valid) = match s_len {
Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)),
None => emsa_pss_get_salt_len(db, em_len, h_len),
};

// 11. Let salt be the last s_len octets of DB.
let salt = &db[db.len() - s_len..];
Expand Down Expand Up @@ -281,7 +324,7 @@ pub(crate) fn emsa_pss_verify(
pub(crate) fn emsa_pss_verify_digest<D>(
m_hash: &[u8],
em: &mut [u8],
s_len: usize,
s_len: Option<usize>,
key_bits: usize,
) -> Result<()>
where
Expand All @@ -307,7 +350,10 @@ where
// to zero.
db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);

let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len);
let (s_len, salt_valid) = match s_len {
Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)),
None => emsa_pss_get_salt_len(db, em_len, h_len),
};

// 11. Let salt be the last s_len octets of DB.
let salt = &db[db.len() - s_len..];
Expand Down
49 changes: 43 additions & 6 deletions src/pss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ pub struct Pss {
pub digest: Box<dyn DynDigest + Send + Sync>,

/// Salt length.
pub salt_len: usize,
/// Required for signing, optional for verifying.
pub salt_len: Option<usize>,
}

impl Pss {
Expand All @@ -66,7 +67,7 @@ impl Pss {
Self {
blinded: false,
digest: Box::new(T::new()),
salt_len: len,
salt_len: Some(len),
}
}

Expand All @@ -84,7 +85,7 @@ impl Pss {
Self {
blinded: true,
digest: Box::new(T::new()),
salt_len: len,
salt_len: Some(len),
}
}
}
Expand All @@ -101,7 +102,7 @@ impl SignatureScheme for Pss {
self.blinded,
priv_key,
hashed,
self.salt_len,
self.salt_len.expect("salt_len to be Some"),
&mut *self.digest,
)
}
Expand Down Expand Up @@ -134,7 +135,7 @@ pub(crate) fn verify(
sig: &BoxedUint,
sig_len: usize,
digest: &mut dyn DynDigest,
salt_len: usize,
salt_len: Option<usize>,
) -> Result<()> {
if sig_len != pub_key.size() {
return Err(Error::Verification);
Expand All @@ -149,7 +150,7 @@ pub(crate) fn verify_digest<D>(
pub_key: &RsaPublicKey,
hashed: &[u8],
sig: &BoxedUint,
salt_len: usize,
salt_len: Option<usize>,
) -> Result<()>
where
D: Digest + FixedOutputReset,
Expand Down Expand Up @@ -261,6 +262,7 @@ mod test {
use crate::pss::{BlindedSigningKey, Pss, Signature, SigningKey, VerifyingKey};
use crate::{RsaPrivateKey, RsaPublicKey};

use crate::traits::PublicKeyParts;
use hex_literal::hex;
use pkcs1::DecodeRsaPrivateKey;
use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
Expand Down Expand Up @@ -622,4 +624,39 @@ tAboUGBxTDq3ZroNism3DaMIbKPyYrAqhKov1h5V
.expect("failed to verify");
}
}

#[test]
// Tests the case where the salt length used for signing differs from the default length
// while the verifier uses auto-detection.
fn test_sign_and_verify_pss_differing_salt_len() {
let priv_key = get_private_key();

let tests = ["test\n"];
let mut rng = ChaCha8Rng::from_seed([42; 32]);

// signing keys using different salt lengths
let signing_keys = [
// default salt length
SigningKey::<Sha1>::new(priv_key.clone()),
// maximum salt length
SigningKey::<Sha1>::new_with_salt_len(
priv_key.clone(),
priv_key.size() - Sha1::output_size() - 2,
),
// unsalted
SigningKey::<Sha1>::new_with_salt_len(priv_key.clone(), 0),
];

// verifying key uses default salt length strategy
let verifying_key = VerifyingKey::<Sha1>::new_with_auto_salt_len(priv_key.to_public_key());

for test in tests {
for signing_key in &signing_keys {
let sig = signing_key.sign_with_rng(&mut rng, test.as_bytes());
verifying_key
.verify(test.as_bytes(), &sig)
.expect("verification to succeed");
}
}
}
}
2 changes: 1 addition & 1 deletion src/pss/blinded_signing_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ where
fn verifying_key(&self) -> Self::VerifyingKey {
VerifyingKey {
inner: self.inner.to_public_key(),
salt_len: self.salt_len,
salt_len: Some(self.salt_len),
phantom: Default::default(),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/pss/signing_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ where
fn verifying_key(&self) -> Self::VerifyingKey {
VerifyingKey {
inner: self.inner.to_public_key(),
salt_len: self.salt_len,
salt_len: Some(self.salt_len),
phantom: Default::default(),
}
}
Expand Down
16 changes: 13 additions & 3 deletions src/pss/verifying_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ where
D: Digest,
{
pub(super) inner: RsaPublicKey,
pub(super) salt_len: usize,
pub(super) salt_len: Option<usize>,
pub(super) phantom: PhantomData<D>,
}

Expand All @@ -45,13 +45,23 @@ where
pub fn new_with_salt_len(key: RsaPublicKey, salt_len: usize) -> Self {
Self {
inner: key,
salt_len,
salt_len: Some(salt_len),
phantom: Default::default(),
}
}

/// Create a new RSASSA-PSS verifying key.
/// Attempts to automatically detect the salt length.
pub fn new_with_auto_salt_len(key: RsaPublicKey) -> Self {
Self {
inner: key,
salt_len: None,
phantom: Default::default(),
}
}

/// Return specified salt length for this key
pub fn salt_len(&self) -> usize {
pub fn salt_len(&self) -> Option<usize> {
self.salt_len
}
}
Expand Down
Loading