diff --git a/Cargo.lock b/Cargo.lock index 75e58c4..e819989 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9,18 +9,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d7e60934ceec538daadb9d8432424ed043a904d8e0243f3c6446bce549a46ac" [[package]] -name = "pldm-lib" +name = "pldm-common" version = "0.1.0" dependencies = [ "bitfield", "zerocopy", ] +[[package]] +name = "pldm-lib" +version = "0.1.0" +dependencies = [ + "pldm-common", +] + [[package]] name = "proc-macro2" -version = "1.0.95" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -36,9 +43,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.104" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -47,24 +54,24 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" [[package]] name = "zerocopy" -version = "0.8.26" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.26" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index d798a4b..afc0e6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,22 @@ name = "pldm-lib" version = "0.1.0" authors = ["Caliptra contributors", "OpenPRoT contributors"] -edition = "2024" +edition = "2021" -[dependencies] -zerocopy = {version = "0.8.17", features = ["derive"]} +[workspace] +members = [ + "pldm-common" +] + +[workspace.package] +version = "0.1.0" +authors = ["Caliptra contributors", "OpenPRoT contributors"] +edition = "2021" + +[workspace.dependencies] +pldm-common = { path = "pldm-common" } bitfield = "0.14.0" +zerocopy = { version = "0.8.17", features = ["derive"] } + +[dependencies] +pldm-common.workspace = true diff --git a/pldm-common/Cargo.toml b/pldm-common/Cargo.toml new file mode 100644 index 0000000..ac80d23 --- /dev/null +++ b/pldm-common/Cargo.toml @@ -0,0 +1,12 @@ +# Licensed under the Apache-2.0 license + +[package] +name = "pldm-common" +version.workspace = true +edition.workspace = true +authors.workspace = true + +[dependencies] +zerocopy.workspace = true +bitfield.workspace = true + diff --git a/src/codec.rs b/pldm-common/src/codec.rs similarity index 100% rename from src/codec.rs rename to pldm-common/src/codec.rs diff --git a/pldm-common/src/error.rs b/pldm-common/src/error.rs new file mode 100644 index 0000000..17dd6a2 --- /dev/null +++ b/pldm-common/src/error.rs @@ -0,0 +1,36 @@ +// Licensed under the Apache-2.0 license + +#[derive(Debug, Clone, PartialEq)] +pub enum PldmError { + InvalidData, + InvalidLength, + InvalidMsgType, + InvalidProtocolVersion, + UnsupportedCmd, + UnsupportedPldmType, + InvalidCompletionCode, + InvalidTransferOpFlag, + InvalidTransferRespFlag, + + InvalidVersionStringType, + InvalidVersionStringLength, + InvalidFdState, + InvalidDescriptorType, + InvalidDescriptorLength, + InvalidDescriptorCount, + InvalidComponentClassification, + InvalidComponentResponseCode, + InvalidComponentCompatibilityResponse, + InvalidComponentCompatibilityResponseCode, + InvalidTransferResult, + InvalidVerifyResult, + InvalidApplyResult, + InvalidGetStatusReasonCode, + InvalidAuxStateStatus, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum UtilError { + InvalidMctpPayloadLength, + InvalidMctpMsgType, +} diff --git a/pldm-common/src/lib.rs b/pldm-common/src/lib.rs new file mode 100644 index 0000000..c57ad8f --- /dev/null +++ b/pldm-common/src/lib.rs @@ -0,0 +1,16 @@ +#![no_std] + +// Re-export core types for no_std compatibility +pub use core::{ + option::Option::{self, Some, None}, + result::Result::{self, Ok, Err}, +}; + +// Licensed under the Apache-2.0 license + + +pub mod codec; +pub mod error; +pub mod message; +pub mod protocol; +pub mod util; \ No newline at end of file diff --git a/src/message/control.rs b/pldm-common/src/message/control.rs similarity index 100% rename from src/message/control.rs rename to pldm-common/src/message/control.rs diff --git a/src/message/firmware_update/activate_fw.rs b/pldm-common/src/message/firmware_update/activate_fw.rs similarity index 100% rename from src/message/firmware_update/activate_fw.rs rename to pldm-common/src/message/firmware_update/activate_fw.rs diff --git a/src/message/firmware_update/apply_complete.rs b/pldm-common/src/message/firmware_update/apply_complete.rs similarity index 100% rename from src/message/firmware_update/apply_complete.rs rename to pldm-common/src/message/firmware_update/apply_complete.rs diff --git a/src/message/firmware_update/get_fw_params.rs b/pldm-common/src/message/firmware_update/get_fw_params.rs similarity index 100% rename from src/message/firmware_update/get_fw_params.rs rename to pldm-common/src/message/firmware_update/get_fw_params.rs diff --git a/src/message/firmware_update/get_status.rs b/pldm-common/src/message/firmware_update/get_status.rs similarity index 100% rename from src/message/firmware_update/get_status.rs rename to pldm-common/src/message/firmware_update/get_status.rs diff --git a/src/message/firmware_update/mod.rs b/pldm-common/src/message/firmware_update/mod.rs similarity index 100% rename from src/message/firmware_update/mod.rs rename to pldm-common/src/message/firmware_update/mod.rs diff --git a/src/message/firmware_update/pass_component.rs b/pldm-common/src/message/firmware_update/pass_component.rs similarity index 100% rename from src/message/firmware_update/pass_component.rs rename to pldm-common/src/message/firmware_update/pass_component.rs diff --git a/src/message/firmware_update/query_devid.rs b/pldm-common/src/message/firmware_update/query_devid.rs similarity index 100% rename from src/message/firmware_update/query_devid.rs rename to pldm-common/src/message/firmware_update/query_devid.rs diff --git a/src/message/firmware_update/request_cancel.rs b/pldm-common/src/message/firmware_update/request_cancel.rs similarity index 100% rename from src/message/firmware_update/request_cancel.rs rename to pldm-common/src/message/firmware_update/request_cancel.rs diff --git a/src/message/firmware_update/request_fw_data.rs b/pldm-common/src/message/firmware_update/request_fw_data.rs similarity index 100% rename from src/message/firmware_update/request_fw_data.rs rename to pldm-common/src/message/firmware_update/request_fw_data.rs diff --git a/src/message/firmware_update/request_update.rs b/pldm-common/src/message/firmware_update/request_update.rs similarity index 100% rename from src/message/firmware_update/request_update.rs rename to pldm-common/src/message/firmware_update/request_update.rs diff --git a/src/message/firmware_update/transfer_complete.rs b/pldm-common/src/message/firmware_update/transfer_complete.rs similarity index 100% rename from src/message/firmware_update/transfer_complete.rs rename to pldm-common/src/message/firmware_update/transfer_complete.rs diff --git a/src/message/firmware_update/update_component.rs b/pldm-common/src/message/firmware_update/update_component.rs similarity index 100% rename from src/message/firmware_update/update_component.rs rename to pldm-common/src/message/firmware_update/update_component.rs diff --git a/src/message/firmware_update/verify_complete.rs b/pldm-common/src/message/firmware_update/verify_complete.rs similarity index 100% rename from src/message/firmware_update/verify_complete.rs rename to pldm-common/src/message/firmware_update/verify_complete.rs diff --git a/src/message/mod.rs b/pldm-common/src/message/mod.rs similarity index 100% rename from src/message/mod.rs rename to pldm-common/src/message/mod.rs diff --git a/src/protocol/base.rs b/pldm-common/src/protocol/base.rs similarity index 100% rename from src/protocol/base.rs rename to pldm-common/src/protocol/base.rs diff --git a/src/protocol/firmware_update.rs b/pldm-common/src/protocol/firmware_update.rs similarity index 99% rename from src/protocol/firmware_update.rs rename to pldm-common/src/protocol/firmware_update.rs index f959f50..4033e96 100644 --- a/src/protocol/firmware_update.rs +++ b/pldm-common/src/protocol/firmware_update.rs @@ -268,7 +268,8 @@ pub fn get_descriptor_length(descriptor_type: DescriptorType) -> usize { } #[derive(Debug, Copy, Clone, PartialEq)] -#[repr(C)] +#[derive(FromBytes, IntoBytes)] +//#[repr(C)] pub struct Descriptor { pub descriptor_type: u16, pub descriptor_length: u16, diff --git a/src/protocol/mod.rs b/pldm-common/src/protocol/mod.rs similarity index 100% rename from src/protocol/mod.rs rename to pldm-common/src/protocol/mod.rs diff --git a/src/protocol/version.rs b/pldm-common/src/protocol/version.rs similarity index 100% rename from src/protocol/version.rs rename to pldm-common/src/protocol/version.rs diff --git a/src/util/fw_component.rs b/pldm-common/src/util/fw_component.rs similarity index 100% rename from src/util/fw_component.rs rename to pldm-common/src/util/fw_component.rs diff --git a/src/util/mctp_transport.rs b/pldm-common/src/util/mctp_transport.rs similarity index 100% rename from src/util/mctp_transport.rs rename to pldm-common/src/util/mctp_transport.rs diff --git a/src/util/mod.rs b/pldm-common/src/util/mod.rs similarity index 100% rename from src/util/mod.rs rename to pldm-common/src/util/mod.rs diff --git a/src/Cargo.toml b/src/Cargo.toml new file mode 100644 index 0000000..403f5d5 --- /dev/null +++ b/src/Cargo.toml @@ -0,0 +1,12 @@ +# Licensed under the Apache-2.0 license + +[package] +name = "pldm-service" +version.workspace = true +edition.workspace = true +authors.workspace = true + +[dependencies] +zerocopy.workspace = true +bitfield.workspace = true +pldm-common.workspace = true diff --git a/src/cmd_interface.rs b/src/cmd_interface.rs new file mode 100644 index 0000000..3a0b94f --- /dev/null +++ b/src/cmd_interface.rs @@ -0,0 +1,217 @@ +// Licensed under the Apache-2.0 license +use crate::control_context::{ControlContext, CtrlCmdResponder, ProtocolCapability}; +use crate::error::MsgHandlerError; +use crate::firmware_device::fd_context::FirmwareDeviceContext; +use crate::transport::{self, MctpTransport}; +use core::sync::atomic::{AtomicBool, Ordering}; +use pldm_common::codec::PldmCodec; +use pldm_common::protocol::base::{ + PldmBaseCompletionCode, PldmControlCmd, PldmFailureResponse, PldmMsgHeader, PldmSupportedType, +}; +use pldm_common::protocol::firmware_update::FwUpdateCmd; +use pldm_common::util::mctp_transport::{ + construct_mctp_pldm_msg, extract_pldm_msg, PLDM_MSG_OFFSET, +}; + +pub type PldmCompletionErrorCode = u8; + +// Helper function to write a failure response message into payload +pub(crate) fn generate_failure_response( + payload: &mut [u8], + completion_code: u8, +) -> Result { + let header = PldmMsgHeader::decode(payload).map_err(MsgHandlerError::Codec)?; + let resp = PldmFailureResponse { + hdr: header.into_response(), + completion_code, + }; + resp.encode(payload).map_err(MsgHandlerError::Codec) +} + +pub struct CmdInterface<'a, MctpTransport> { + ctrl_ctx: ControlContext<'a>, + fd_ctx: FirmwareDeviceContext, + transport: &'a mut MctpTransport, + busy: AtomicBool, +} + +impl<'a> CmdInterface<'a, MctpTransport> { + pub fn new( + protocol_capabilities: &'a [ProtocolCapability], + transport: &'a mut MctpTransport, + ) -> Self { + let ctrl_ctx = ControlContext::new(protocol_capabilities); + let fd_ctx = FirmwareDeviceContext::new(); + Self { + ctrl_ctx, + fd_ctx, + transport, + busy: AtomicBool::new(false), + } + } + + pub fn handle_responder_msg( + &mut self, + msg_buf: &mut [u8], + ) -> Result<(), MsgHandlerError> { + // Receive msg from mctp transport + self.transport + .receive_request(msg_buf) + .map_err(MsgHandlerError::Transport)?; + + // Process the request + let resp_len = self.process_request(msg_buf)?; + + // Send the response + self.transport + .send_response(&msg_buf[..resp_len]) + .map_err(MsgHandlerError::Transport) + } + + pub fn handle_initiator_msg( + &mut self, + msg_buf: &mut [u8], + ) -> Result<(), MsgHandlerError> { + // Retrieve the UA EID from the configuration + let ua_eid: u8 = crate::config::UA_EID; + + // Prepare the request payload + let payload = construct_mctp_pldm_msg(msg_buf).map_err(MsgHandlerError::Util)?; + let reserved_len = PLDM_MSG_OFFSET; + + // Generate the request + let req_len = self.fd_ctx.fd_progress(payload)?; + if req_len == 0 { + return Ok(()); + } + + // Send the request + self.transport + .send_request(ua_eid, &msg_buf[..req_len + reserved_len]) + .map_err(MsgHandlerError::Transport)?; + + // Wait for and process the response + self.transport + .receive_response(msg_buf) + .map_err(MsgHandlerError::Transport)?; + + let payload = extract_pldm_msg(msg_buf).map_err(MsgHandlerError::Util)?; + + // Handle the response + self.fd_ctx.handle_response(payload)?; + + Ok(()) + } + + fn process_request(&mut self, msg_buf: &mut [u8]) -> Result { + // Check if the handler is busy processing a request + if self.busy.load(Ordering::SeqCst) { + return Err(MsgHandlerError::NotReady); + } + + self.busy.store(true, Ordering::SeqCst); + + // Get the pldm payload from msg_buf + let payload = &mut msg_buf[PLDM_MSG_OFFSET..]; + let reserved_len = PLDM_MSG_OFFSET; + + let (pldm_type, cmd_opcode) = match self.preprocess_request(payload) { + Ok(result) => result, + Err(e) => { + self.busy.store(false, Ordering::SeqCst); + return Ok(reserved_len + generate_failure_response(payload, e)?); + } + }; + + let resp_len = match pldm_type { + PldmSupportedType::Base => self.process_control_cmd(cmd_opcode, payload), + PldmSupportedType::FwUpdate => self.process_fw_update_cmd(cmd_opcode, payload), + _ => { + unreachable!() + } + }; + + self.busy.store(false, Ordering::SeqCst); + + match resp_len { + Ok(bytes) => Ok(reserved_len + bytes), + Err(e) => Err(e), + } + } + + fn process_control_cmd( + &self, + cmd_opcode: u8, + payload: &mut [u8], + ) -> Result { + match PldmControlCmd::try_from(cmd_opcode) { + Ok(cmd) => match cmd { + PldmControlCmd::GetTid => self.ctrl_ctx.get_tid_rsp(payload), + PldmControlCmd::SetTid => self.ctrl_ctx.set_tid_rsp(payload), + PldmControlCmd::GetPldmTypes => self.ctrl_ctx.get_pldm_types_rsp(payload), + PldmControlCmd::GetPldmCommands => self.ctrl_ctx.get_pldm_commands_rsp(payload), + PldmControlCmd::GetPldmVersion => self.ctrl_ctx.get_pldm_version_rsp(payload), + }, + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::UnsupportedPldmCmd as u8) + } + } + } + + fn process_fw_update_cmd( + &mut self, + cmd_opcode: u8, + payload: &mut [u8], + ) -> Result { + match FwUpdateCmd::try_from(cmd_opcode) { + Ok(cmd) => match cmd { + FwUpdateCmd::QueryDeviceIdentifiers => self.fd_ctx.query_devid_rsp(payload), + FwUpdateCmd::GetFirmwareParameters => { + self.fd_ctx.get_firmware_parameters_rsp(payload) + } + FwUpdateCmd::RequestUpdate => self.fd_ctx.request_update_rsp(payload), + FwUpdateCmd::PassComponentTable => self.fd_ctx.pass_component_rsp(payload), + FwUpdateCmd::UpdateComponent => self.fd_ctx.update_component_rsp(payload), + + FwUpdateCmd::ActivateFirmware => self.fd_ctx.activate_firmware_rsp(payload), + FwUpdateCmd::CancelUpdateComponent => { + self.fd_ctx.cancel_update_component_rsp(payload) + } + FwUpdateCmd::CancelUpdate => self.fd_ctx.cancel_update_rsp(payload), + FwUpdateCmd::GetStatus => self.fd_ctx.get_status_rsp(payload), + _ => generate_failure_response( + payload, + PldmBaseCompletionCode::UnsupportedPldmCmd as u8, + ), + }, + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::UnsupportedPldmCmd as u8) + } + } + } + + fn preprocess_request( + &self, + payload: &[u8], + ) -> Result<(PldmSupportedType, u8), PldmCompletionErrorCode> { + let header = PldmMsgHeader::decode(payload) + .map_err(|_| PldmBaseCompletionCode::InvalidData as u8)?; + if !(header.is_request() && header.is_hdr_ver_valid()) { + Err(PldmBaseCompletionCode::InvalidData as u8)?; + } + + let pldm_type = PldmSupportedType::try_from(header.pldm_type()) + .map_err(|_| PldmBaseCompletionCode::InvalidPldmType as u8)?; + + if !self.ctrl_ctx.is_supported_type(pldm_type) { + Err(PldmBaseCompletionCode::InvalidPldmType as u8)?; + } + + let cmd_opcode = header.cmd_code(); + if self.ctrl_ctx.is_supported_command(pldm_type, cmd_opcode) { + Ok((pldm_type, cmd_opcode)) + } else { + Err(PldmBaseCompletionCode::UnsupportedPldmCmd as u8) + } + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..fd0089a --- /dev/null +++ b/src/config.rs @@ -0,0 +1,44 @@ +// Licensed under the Apache-2.0 license +use crate::control_context::ProtocolCapability; +use pldm_common::protocol::base::{PldmControlCmd, PldmSupportedType}; +use pldm_common::protocol::firmware_update::{FwUpdateCmd, PldmFdTime}; + +pub const PLDM_PROTOCOL_CAP_COUNT: usize = 2; +pub const FD_MAX_XFER_SIZE: usize = 512; // Arbitrary limit and change as needed. +pub const DEFAULT_FD_T1_TIMEOUT: PldmFdTime = 120000; // FD_T1 update mode idle timeout, range is [60s, 120s]. +pub const DEFAULT_FD_T2_RETRY_TIME: PldmFdTime = 5000; // FD_T2 retry request for firmware data, range is [1s, 5s]. +pub const INSTANCE_ID_COUNT: u8 = 32; +pub const UA_EID: u8 = 8; // Update Agent Endpoint ID for testing. + +pub static PLDM_PROTOCOL_CAPABILITIES: [ProtocolCapability<'static>; PLDM_PROTOCOL_CAP_COUNT] = [ + ProtocolCapability { + pldm_type: PldmSupportedType::Base, + protocol_version: 0xF1F1F000, //"1.1.0" + supported_commands: &[ + PldmControlCmd::SetTid as u8, + PldmControlCmd::GetTid as u8, + PldmControlCmd::GetPldmCommands as u8, + PldmControlCmd::GetPldmVersion as u8, + PldmControlCmd::GetPldmTypes as u8, + ], + }, + ProtocolCapability { + pldm_type: PldmSupportedType::FwUpdate, + protocol_version: 0xF1F3F000, // "1.3.0" + supported_commands: &[ + FwUpdateCmd::QueryDeviceIdentifiers as u8, + FwUpdateCmd::GetFirmwareParameters as u8, + FwUpdateCmd::RequestUpdate as u8, + FwUpdateCmd::PassComponentTable as u8, + FwUpdateCmd::UpdateComponent as u8, + FwUpdateCmd::RequestFirmwareData as u8, + FwUpdateCmd::TransferComplete as u8, + FwUpdateCmd::VerifyComplete as u8, + FwUpdateCmd::ApplyComplete as u8, + FwUpdateCmd::ActivateFirmware as u8, + FwUpdateCmd::GetStatus as u8, + FwUpdateCmd::CancelUpdateComponent as u8, + FwUpdateCmd::CancelUpdate as u8, + ], + }, +]; \ No newline at end of file diff --git a/src/control_context.rs b/src/control_context.rs new file mode 100644 index 0000000..814d379 --- /dev/null +++ b/src/control_context.rs @@ -0,0 +1,308 @@ +// Licensed under the Apache-2.0 license +use crate::cmd_interface::generate_failure_response; +use crate::error::MsgHandlerError; +use core::sync::atomic::{AtomicUsize, Ordering}; +use pldm_common::codec::PldmCodec; +use pldm_common::error::PldmError; +use pldm_common::message::control::{ + GetPldmCommandsRequest, GetPldmCommandsResponse, GetPldmTypeRequest, GetPldmTypeResponse, + GetPldmVersionRequest, GetPldmVersionResponse, GetTidRequest, GetTidResponse, SetTidRequest, + SetTidResponse, +}; +use pldm_common::protocol::base::{ + PldmBaseCompletionCode, PldmControlCompletionCode, PldmSupportedType, TransferOperationFlag, + TransferRespFlag, +}; +use pldm_common::protocol::version::{PldmVersion, ProtocolVersionStr, Ver32}; + +pub type Tid = u8; +pub type CmdOpCode = u8; +pub const UNASSIGNED_TID: Tid = 0; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct ProtocolCapability<'a> { + pub pldm_type: PldmSupportedType, + pub protocol_version: Ver32, + pub supported_commands: &'a [CmdOpCode], +} + +impl<'a> ProtocolCapability<'a> { + pub fn new( + pldm_type: PldmSupportedType, + protocol_version: ProtocolVersionStr, + supported_commands: &'a [CmdOpCode], + ) -> Result { + Ok(Self { + pldm_type, + protocol_version: match PldmVersion::try_from(protocol_version) { + Ok(ver) => ver.bcd_encode_to_ver32(), + Err(_) => return Err(PldmError::InvalidProtocolVersion), + }, + supported_commands, + }) + } +} + +/// `ControlContext` is a structure that holds the control context for the PLDM library. +/// +/// # Fields +/// +/// * `tid` - An atomic unsigned size integer representing the transaction ID. +/// * `capabilities` - A reference to a slice of `ProtocolCapability` which represents the protocol capabilities. +pub struct ControlContext<'a> { + tid: AtomicUsize, + capabilities: &'a [ProtocolCapability<'a>], +} + +impl<'a> ControlContext<'a> { + pub fn new(capabilities: &'a [ProtocolCapability<'a>]) -> Self { + Self { + tid: AtomicUsize::new(UNASSIGNED_TID as usize), + capabilities, + } + } + + pub fn get_tid(&self) -> Tid { + self.tid.load(Ordering::SeqCst) as Tid + } + + pub fn set_tid(&self, tid: Tid) { + self.tid.store(tid as usize, Ordering::SeqCst); + } + + pub fn get_supported_commands( + &self, + pldm_type: PldmSupportedType, + protocol_version: Ver32, + ) -> Option<&[CmdOpCode]> { + self.capabilities + .iter() + .find(|cap| cap.pldm_type == pldm_type && cap.protocol_version == protocol_version) + .map(|cap| cap.supported_commands) + } + + pub fn get_protocol_versions( + &self, + pldm_type: PldmSupportedType, + versions: &mut [Ver32], + ) -> usize { + let mut count = 0; + for cap in self + .capabilities + .iter() + .filter(|cap| cap.pldm_type == pldm_type) + { + if count < versions.len() { + versions[count] = cap.protocol_version; + count += 1; + } else { + break; + } + } + count + } + + pub fn get_supported_types(&self, types: &mut [u8]) -> usize { + let mut count = 0; + for cap in self.capabilities.iter() { + let pldm_type = cap.pldm_type as u8; + if !types[..count].contains(&pldm_type) { + if count < types.len() { + types[count] = pldm_type; + count += 1; + } else { + break; + } + } + } + count + } + + pub fn is_supported_type(&self, pldm_type: PldmSupportedType) -> bool { + self.capabilities + .iter() + .any(|cap| cap.pldm_type == pldm_type) + } + + pub fn is_supported_version( + &self, + pldm_type: PldmSupportedType, + protocol_version: Ver32, + ) -> bool { + self.capabilities + .iter() + .any(|cap| cap.pldm_type == pldm_type && cap.protocol_version == protocol_version) + } + + pub fn is_supported_command(&self, pldm_type: PldmSupportedType, cmd: u8) -> bool { + self.capabilities + .iter() + .find(|cap| cap.pldm_type == pldm_type) + .is_some_and(|cap| cap.supported_commands.contains(&cmd)) + } +} + +/// Trait representing a responder for control commands in the PLDM protocol. +/// Implementors of this trait are responsible for handling various control commands +/// and generating appropriate responses. +/// +/// # Methods +/// +/// - `get_tid_rsp`: Generates a response for the "Get TID" command. +/// - `set_tid_rsp`: Generates a response for the "Set TID" command. +/// - `get_pldm_types_rsp`: Generates a response for the "Get PLDM Types" command. +/// - `get_pldm_commands_rsp`: Generates a response for the "Get PLDM Commands" command. +/// - `get_pldm_version_rsp`: Generates a response for the "Get PLDM Version" command. +/// +/// Each method takes a mutable reference to a payload buffer and returns a `Result` +/// containing the size of the response or a `MsgHandlerError` if an error occurs. +pub trait CtrlCmdResponder { + fn get_tid_rsp(&self, payload: &mut [u8]) -> Result; + fn set_tid_rsp(&self, payload: &mut [u8]) -> Result; + fn get_pldm_types_rsp(&self, payload: &mut [u8]) -> Result; + fn get_pldm_commands_rsp(&self, payload: &mut [u8]) -> Result; + fn get_pldm_version_rsp(&self, payload: &mut [u8]) -> Result; +} + +impl CtrlCmdResponder for ControlContext<'_> { + fn get_tid_rsp(&self, payload: &mut [u8]) -> Result { + let req = GetTidRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + let resp = GetTidResponse::new( + req.hdr.instance_id(), + self.get_tid(), + PldmBaseCompletionCode::Success as u8, + ); + resp.encode(payload).map_err(MsgHandlerError::Codec) + } + + fn set_tid_rsp(&self, payload: &mut [u8]) -> Result { + let req = SetTidRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + self.set_tid(req.tid); + let resp = + SetTidResponse::new(req.hdr.instance_id(), PldmBaseCompletionCode::Success as u8); + resp.encode(payload).map_err(MsgHandlerError::Codec) + } + + fn get_pldm_types_rsp(&self, payload: &mut [u8]) -> Result { + let req = GetPldmTypeRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + let mut types = [0x0u8; 6]; + let num_types = self.get_supported_types(&mut types); + let resp = GetPldmTypeResponse::new( + req.hdr.instance_id(), + PldmBaseCompletionCode::Success as u8, + &types[..num_types], + ); + resp.encode(payload).map_err(MsgHandlerError::Codec) + } + + fn get_pldm_commands_rsp(&self, payload: &mut [u8]) -> Result { + let req = match GetPldmCommandsRequest::decode(payload) { + Ok(req) => req, + Err(_) => { + return generate_failure_response( + payload, + PldmBaseCompletionCode::InvalidLength as u8, + ) + } + }; + + let pldm_type_in_req = match PldmSupportedType::try_from(req.pldm_type) { + Ok(pldm_type) => pldm_type, + Err(_) => { + return generate_failure_response( + payload, + PldmControlCompletionCode::InvalidPldmTypeInRequestData as u8, + ) + } + }; + + if !self.is_supported_type(pldm_type_in_req) { + return generate_failure_response( + payload, + PldmControlCompletionCode::InvalidPldmTypeInRequestData as u8, + ); + } + + let version_in_req = req.protocol_version; + if !self.is_supported_version(pldm_type_in_req, version_in_req) { + return generate_failure_response( + payload, + PldmControlCompletionCode::InvalidPldmVersionInRequestData as u8, + ); + } + + let cmds = self + .get_supported_commands(pldm_type_in_req, version_in_req) + .unwrap(); + + let resp = GetPldmCommandsResponse::new( + req.hdr.instance_id(), + PldmBaseCompletionCode::Success as u8, + cmds, + ); + + match resp.encode(payload) { + Ok(bytes) => Ok(bytes), + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + fn get_pldm_version_rsp(&self, payload: &mut [u8]) -> Result { + let req = match GetPldmVersionRequest::decode(payload) { + Ok(req) => req, + Err(_) => { + return generate_failure_response( + payload, + PldmBaseCompletionCode::InvalidLength as u8, + ) + } + }; + + let pldm_type_in_req = match PldmSupportedType::try_from(req.pldm_type) { + Ok(pldm_type) => pldm_type, + Err(_) => { + return generate_failure_response( + payload, + PldmControlCompletionCode::InvalidPldmTypeInRequestData as u8, + ) + } + }; + + if !self.is_supported_type(pldm_type_in_req) { + return generate_failure_response( + payload, + PldmControlCompletionCode::InvalidPldmTypeInRequestData as u8, + ); + } + + if req.transfer_op_flag != TransferOperationFlag::GetFirstPart as u8 { + return generate_failure_response( + payload, + PldmControlCompletionCode::InvalidTransferOperationFlag as u8, + ); + } + + let mut versions = [0u32; 2]; + if self.get_protocol_versions(pldm_type_in_req, &mut versions) == 0 { + return generate_failure_response(payload, PldmBaseCompletionCode::Error as u8); + } + + // Only one version is supported for now + let resp = GetPldmVersionResponse { + hdr: req.hdr.into_response(), + completion_code: PldmBaseCompletionCode::Success as u8, + next_transfer_handle: 0, + transfer_rsp_flag: TransferRespFlag::StartAndEnd as u8, + version_data: versions[0], + }; + + match resp.encode(payload) { + Ok(bytes) => Ok(bytes), + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } +} \ No newline at end of file diff --git a/src/error.rs b/src/error.rs index 17dd6a2..f58428a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,36 +1,17 @@ // Licensed under the Apache-2.0 license +use crate::firmware_device::fd_ops::FdOpsError; +use crate::transport::TransportError; +use pldm_common::codec::PldmCodecError; +use pldm_common::error::{PldmError, UtilError}; -#[derive(Debug, Clone, PartialEq)] -pub enum PldmError { - InvalidData, - InvalidLength, - InvalidMsgType, - InvalidProtocolVersion, - UnsupportedCmd, - UnsupportedPldmType, - InvalidCompletionCode, - InvalidTransferOpFlag, - InvalidTransferRespFlag, - - InvalidVersionStringType, - InvalidVersionStringLength, - InvalidFdState, - InvalidDescriptorType, - InvalidDescriptorLength, - InvalidDescriptorCount, - InvalidComponentClassification, - InvalidComponentResponseCode, - InvalidComponentCompatibilityResponse, - InvalidComponentCompatibilityResponseCode, - InvalidTransferResult, - InvalidVerifyResult, - InvalidApplyResult, - InvalidGetStatusReasonCode, - InvalidAuxStateStatus, -} - -#[derive(Debug, Clone, PartialEq)] -pub enum UtilError { - InvalidMctpPayloadLength, - InvalidMctpMsgType, +/// Handle non-protocol specific error conditions. +#[derive(Debug)] +pub enum MsgHandlerError { + Codec(PldmCodecError), + Transport(TransportError), + PldmCommon(PldmError), + Util(UtilError), + FdOps(FdOpsError), + FdInitiatorModeError, + NotReady, } diff --git a/src/firmware_device/fd_context.rs b/src/firmware_device/fd_context.rs new file mode 100644 index 0000000..82eec88 --- /dev/null +++ b/src/firmware_device/fd_context.rs @@ -0,0 +1,1046 @@ +// Licensed under the Apache-2.0 license + +use crate::cmd_interface::generate_failure_response; +use crate::error::MsgHandlerError; +use crate::firmware_device::fd_internal::{FdInternal, FdReqState}; +use crate::firmware_device::fd_ops::{ComponentOperation, FdOps}; +use pldm_common::codec::PldmCodec; +use pldm_common::message::firmware_update::activate_fw::{ + ActivateFirmwareRequest, ActivateFirmwareResponse, +}; +use pldm_common::message::firmware_update::get_fw_params::{ + FirmwareParameters, GetFirmwareParametersRequest, GetFirmwareParametersResponse, +}; +use pldm_common::message::firmware_update::get_status::ProgressPercent; +use pldm_common::message::firmware_update::pass_component::{ + PassComponentTableRequest, PassComponentTableResponse, +}; +use pldm_common::message::firmware_update::query_devid::{ + QueryDeviceIdentifiersRequest, QueryDeviceIdentifiersResponse, +}; +use pldm_common::message::firmware_update::request_cancel::{ + CancelUpdateComponentRequest, CancelUpdateComponentResponse, CancelUpdateRequest, + CancelUpdateResponse, +}; +use pldm_common::message::firmware_update::request_update::{ + RequestUpdateRequest, RequestUpdateResponse, +}; +use pldm_common::message::firmware_update::transfer_complete::{ + TransferCompleteRequest, TransferResult, +}; +use pldm_common::message::firmware_update::update_component::{ + UpdateComponentRequest, UpdateComponentResponse, +}; + +use pldm_common::codec::PldmCodecError; +use pldm_common::message::firmware_update::apply_complete::{ApplyCompleteRequest, ApplyResult}; +use pldm_common::message::firmware_update::get_status::{ + AuxState, AuxStateStatus, GetStatusReasonCode, GetStatusRequest, GetStatusResponse, + UpdateOptionResp, +}; +use pldm_common::message::firmware_update::request_fw_data::{ + RequestFirmwareDataRequest, RequestFirmwareDataResponseFixed, +}; +use pldm_common::message::firmware_update::verify_complete::{VerifyCompleteRequest, VerifyResult}; +use pldm_common::protocol::base::{ + PldmBaseCompletionCode, PldmMsgHeader, PldmMsgType, TransferRespFlag, +}; +use pldm_common::protocol::firmware_update::{ + ComponentActivationMethods, ComponentCompatibilityResponse, ComponentCompatibilityResponseCode, + ComponentResponse, ComponentResponseCode, Descriptor, FirmwareDeviceState, FwUpdateCmd, + FwUpdateCompletionCode, PldmFirmwareString, UpdateOptionFlags, MAX_DESCRIPTORS_COUNT, + PLDM_FWUP_BASELINE_TRANSFER_SIZE, +}; +use pldm_common::util::fw_component::FirmwareComponent; + +use crate::firmware_device::fd_internal::{ + ApplyState, DownloadState, InitiatorModeState, VerifyState, +}; + +pub struct FirmwareDeviceContext { + ops: FdOps, + internal: FdInternal, +} + +impl FirmwareDeviceContext { + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Self { + ops: FdOps::new(), + internal: FdInternal::new(0,0,0) + } + } + + pub fn query_devid_rsp(&self, payload: &mut [u8]) -> Result { + // Decode the request message + let req = QueryDeviceIdentifiersRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + + let mut device_identifiers: [Descriptor; MAX_DESCRIPTORS_COUNT] = + [Descriptor::default(); MAX_DESCRIPTORS_COUNT]; + + // Get the device identifiers + let descriptor_cnt = self + .ops + .get_device_identifiers(&mut device_identifiers) //CAD Driver Call + .map_err(MsgHandlerError::FdOps)?; + + // Create the response message + let resp = QueryDeviceIdentifiersResponse::new( + req.hdr.instance_id(), + PldmBaseCompletionCode::Success as u8, + &device_identifiers[0], + device_identifiers.get(1..descriptor_cnt), + ) + .map_err(MsgHandlerError::PldmCommon)?; + + match resp.encode(payload) { + Ok(bytes) => Ok(bytes), + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + pub fn get_firmware_parameters_rsp( + &mut self, + payload: &mut [u8], + ) -> Result { + // Decode the request message + let req = GetFirmwareParametersRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + + let mut firmware_params = FirmwareParameters::default(); + self.ops + .get_firmware_parms(&mut firmware_params) + .map_err(MsgHandlerError::FdOps)?; + + // Construct response + let resp = GetFirmwareParametersResponse::new( + req.hdr.instance_id(), + PldmBaseCompletionCode::Success as u8, + &firmware_params, + ); + + match resp.encode(payload) { + Ok(bytes) => Ok(bytes), + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + pub fn request_update_rsp(&mut self, payload: &mut [u8]) -> Result { + // Check if FD is in idle state. Otherwise returns 'ALREADY_IN_UPDATE_MODE' completion code + if self.internal.is_update_mode() { + return generate_failure_response( + payload, + FwUpdateCompletionCode::AlreadyInUpdateMode as u8, + ); + } + + // Set timestamp for FD T1 timeout + self.set_fd_t1_ts(); + + // Decode the request message + let req = RequestUpdateRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + let ua_transfer_size = req.fixed.max_transfer_size as usize; + if ua_transfer_size < PLDM_FWUP_BASELINE_TRANSFER_SIZE { + return generate_failure_response( + payload, + FwUpdateCompletionCode::InvalidTransferLength as u8, + ); + } + + // Get the transfer size for the firmware update operation + let fd_transfer_size = self + .ops + .get_xfer_size(ua_transfer_size) + .map_err(MsgHandlerError::FdOps)?; + + // Set transfer size to the internal state + self.internal.set_xfer_size(fd_transfer_size); + + // Construct response, no metadata or package data. + let resp = RequestUpdateResponse::new( + req.fixed.hdr.instance_id(), + PldmBaseCompletionCode::Success as u8, + 0, + 0, + None, + ); + + match resp.encode(payload) { + Ok(bytes) => { + // Move FD state to 'LearnComponents' + self.internal + .set_fd_state(FirmwareDeviceState::LearnComponents); + Ok(bytes) + } + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + pub fn pass_component_rsp(&mut self, payload: &mut [u8]) -> Result { + // Check if FD is in 'LearnComponents' state. Otherwise returns 'INVALID_STATE' completion code + if self.internal.get_fd_state() != FirmwareDeviceState::LearnComponents { + return generate_failure_response( + payload, + FwUpdateCompletionCode::InvalidStateForCommand as u8, + ); + } + + // Set timestamp for FD T1 timeout + self.set_fd_t1_ts(); + + // Decode the request message + let req = PassComponentTableRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + let transfer_flag = match TransferRespFlag::try_from(req.fixed.transfer_flag) { + Ok(flag) => flag, + Err(_) => { + return generate_failure_response( + payload, + PldmBaseCompletionCode::InvalidData as u8, + ) + } + }; + + // Construct temporary storage for the component + let pass_comp = FirmwareComponent::new( + req.fixed.comp_classification, + req.fixed.comp_identifier, + req.fixed.comp_classification_index, + req.fixed.comp_comparison_stamp, + PldmFirmwareString { + str_type: req.fixed.comp_ver_str_type, + str_len: req.fixed.comp_ver_str_len, + str_data: req.comp_ver_str, + }, + None, + None, + ); + + let mut firmware_params = FirmwareParameters::default(); + self.ops + .get_firmware_parms(&mut firmware_params) + .map_err(MsgHandlerError::FdOps)?; + + let comp_resp_code = self + .ops + .handle_component( + &pass_comp, + &firmware_params, + ComponentOperation::PassComponent, + ) + .map_err(MsgHandlerError::FdOps)?; + + // Construct response + let resp = PassComponentTableResponse::new( + req.fixed.hdr.instance_id(), + PldmBaseCompletionCode::Success as u8, + if comp_resp_code == ComponentResponseCode::CompCanBeUpdated { + ComponentResponse::CompCanBeUpdated + } else { + ComponentResponse::CompCannotBeUpdated + }, + comp_resp_code, + ); + + match resp.encode(payload) { + Ok(bytes) => { + // Move FD state to 'ReadyTransfer' when the last component is passed + if transfer_flag == TransferRespFlag::End + || transfer_flag == TransferRespFlag::StartAndEnd + { + self.internal + .set_fd_state(FirmwareDeviceState::ReadyXfer); + } + Ok(bytes) + } + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + pub fn update_component_rsp(&mut self, payload: &mut [u8]) -> Result { + // Check if FD is in 'ReadyTransfer' state. Otherwise returns 'INVALID_STATE' completion code + if self.internal.get_fd_state() != FirmwareDeviceState::ReadyXfer { + return generate_failure_response( + payload, + FwUpdateCompletionCode::InvalidStateForCommand as u8, + ); + } + + // Set timestamp for FD T1 timeout + self.set_fd_t1_ts(); + + // Decode the request message + let req = UpdateComponentRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + + // Construct temporary storage for the component + let update_comp = FirmwareComponent::new( + req.fixed.comp_classification, + req.fixed.comp_identifier, + req.fixed.comp_classification_index, + req.fixed.comp_comparison_stamp, + PldmFirmwareString { + str_type: req.fixed.comp_ver_str_type, + str_len: req.fixed.comp_ver_str_len, + str_data: req.comp_ver_str, + }, + Some(req.fixed.comp_image_size), + Some(UpdateOptionFlags(req.fixed.update_option_flags)), + ); + + // Store the component info into the internal state. + self.internal.set_component(&update_comp); + + // Adjust the update flags based on the device's capabilities if needed. Currently, the flags are set as received from the UA. + self.internal + .set_update_flags(UpdateOptionFlags(req.fixed.update_option_flags)); + + let mut firmware_params = FirmwareParameters::default(); + self.ops + .get_firmware_parms(&mut firmware_params) + .map_err(MsgHandlerError::FdOps)?; + + let comp_resp_code = self + .ops + .handle_component( + &update_comp, + &firmware_params, + ComponentOperation::UpdateComponent, /* This indicates this is an update request */ + ) + .map_err(MsgHandlerError::FdOps)?; + + // Construct response + let resp = UpdateComponentResponse::new( + req.fixed.hdr.instance_id(), + PldmBaseCompletionCode::Success as u8, + if comp_resp_code == ComponentResponseCode::CompCanBeUpdated { + ComponentCompatibilityResponse::CompCanBeUpdated + } else { + ComponentCompatibilityResponse::CompCannotBeUpdated + }, + ComponentCompatibilityResponseCode::try_from(comp_resp_code as u8).unwrap(), + UpdateOptionFlags(req.fixed.update_option_flags), + 0, + None, + ); + + match resp.encode(payload) { + Ok(bytes) => { + if comp_resp_code == ComponentResponseCode::CompCanBeUpdated { + self.internal + .set_initiator_mode(InitiatorModeState::Download(DownloadState::default())); + // Set up the req for download. + self.internal + .set_fd_req(FdReqState::Ready, false, None, None, None, None); + + // Move FD state machine to download state. + self.internal + .set_fd_state(FirmwareDeviceState::Download); + } + Ok(bytes) + } + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + pub fn activate_firmware_rsp( + &mut self, + payload: &mut [u8], + ) -> Result { + // Check if FD is in 'ReadyTransfer' state. Otherwise returns 'INVALID_STATE' completion code + if self.internal.get_fd_state() != FirmwareDeviceState::ReadyXfer { + return generate_failure_response( + payload, + FwUpdateCompletionCode::InvalidStateForCommand as u8, + ); + } + + // Decode the request message + let req = ActivateFirmwareRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + let self_contained = req.self_contained_activation_req; + + // Validate self_contained value + match self_contained { + 0 | 1 => {} + _ => { + return generate_failure_response( + payload, + PldmBaseCompletionCode::InvalidData as u8, + ) + } + } + + let mut estimated_time = 0u16; + let completion_code = self + .ops + .activate(self_contained, &mut estimated_time) + .map_err(MsgHandlerError::FdOps)?; + + // Construct response + let resp = + ActivateFirmwareResponse::new(req.hdr.instance_id(), completion_code, estimated_time); + + match resp.encode(payload) { + Ok(bytes) => { + if completion_code == PldmBaseCompletionCode::Success as u8 + || completion_code == FwUpdateCompletionCode::ActivationNotRequired as u8 + { + self.internal + .set_fd_state(FirmwareDeviceState::Activate); + self.internal + .set_fd_idle(GetStatusReasonCode::ActivateFw); + } + Ok(bytes) + } + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + pub fn cancel_update_component_rsp( + &mut self, + payload: &mut [u8], + ) -> Result { + // If FD is not in update mode, return 'NOT_IN_UPDATE_MODE' completion code + if !self.internal.is_update_mode() { + return generate_failure_response( + payload, + FwUpdateCompletionCode::NotInUpdateMode as u8, + ); + } + + let fd_state = self.internal.get_fd_state(); + let should_cancel = match fd_state { + FirmwareDeviceState::Download | FirmwareDeviceState::Verify => true, + FirmwareDeviceState::Apply => { + // In apply state, only cancel if not completed successfully + !(self.internal.is_fd_req_complete() + && self.internal.get_fd_req_result() + == Some(ApplyResult::ApplySuccess as u8)) + } + _ => { + return generate_failure_response( + payload, + FwUpdateCompletionCode::InvalidStateForCommand as u8, + ); + } + }; + + if should_cancel { + self.ops + .cancel_update_component(&self.internal.get_component()) + .map_err(MsgHandlerError::FdOps)?; + } + + // Decode the request message + let req = CancelUpdateComponentRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + let completion_code = if should_cancel { + PldmBaseCompletionCode::Success as u8 + } else { + PldmBaseCompletionCode::Error as u8 + }; + + let resp = CancelUpdateComponentResponse::new(req.hdr.instance_id(), completion_code); + match resp.encode(payload) { + Ok(bytes) => { + if should_cancel { + // Set FD state to 'ReadyTransfer' + self.internal + .set_fd_state(FirmwareDeviceState::ReadyXfer); + } + Ok(bytes) + } + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + pub fn cancel_update_rsp(&mut self, payload: &mut [u8]) -> Result { + // If FD is not in update mode, return 'NOT_IN_UPDATE_MODE' completion code + if !self.internal.is_update_mode() { + return generate_failure_response( + payload, + FwUpdateCompletionCode::NotInUpdateMode as u8, + ); + } + + // Set timestamp for FD T1 timeout + self.set_fd_t1_ts(); + + let fd_state = self.internal.get_fd_state(); + let should_cancel = match fd_state { + FirmwareDeviceState::Download | FirmwareDeviceState::Verify => true, + FirmwareDeviceState::Apply => { + // In apply state, only cancel if not completed successfully + !(self.internal.is_fd_req_complete() + && self.internal.get_fd_req().result + == Some(ApplyResult::ApplySuccess as u8)) + } + _ => false, + }; + + if should_cancel { + self.ops + .cancel_update_component(&self.internal.get_component()) + .map_err(MsgHandlerError::FdOps)?; + } + + // Decode the request message + let req = CancelUpdateRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + let completion_code = if should_cancel { + PldmBaseCompletionCode::Success as u8 + } else { + PldmBaseCompletionCode::Error as u8 + }; + + let (non_functioning_component_indication, non_functioning_component_bitmap) = self + .ops + .get_non_functional_component_info() + .map_err(MsgHandlerError::FdOps)?; + + let resp = CancelUpdateResponse::new( + req.hdr.instance_id(), + completion_code, + non_functioning_component_indication, + non_functioning_component_bitmap, + ); + + match resp.encode(payload) { + Ok(bytes) => { + if should_cancel { + // Set FD state to 'Idle' + self.internal + .set_fd_idle(GetStatusReasonCode::CancelUpdate); + } + Ok(bytes) + } + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + pub fn get_status_rsp(&mut self, payload: &mut [u8]) -> Result { + let req = GetStatusRequest::decode(payload).map_err(MsgHandlerError::Codec)?; + + let cur_state = self.internal.get_fd_state(); + let prev_state = self.internal.get_fd_prev_state(); + let (progress_percent, update_flags) = match cur_state { + FirmwareDeviceState::Download => { + let mut progress = ProgressPercent::default(); + let _ = self + .ops + .query_download_progress(&self.internal.get_component(), &mut progress); + let update_flags = self.internal.get_update_flags(); + (progress, update_flags) + } + FirmwareDeviceState::Verify => { + let progress = if let Some(percent) = self.internal.get_fd_verify_progress() { + ProgressPercent::new(percent).unwrap() + } else { + ProgressPercent::default() + }; + let update_flags = self.internal.get_update_flags(); + (progress, update_flags) + } + FirmwareDeviceState::Apply => { + let progress = if let Some(percent) = self.internal.get_fd_apply_progress() { + ProgressPercent::new(percent).unwrap() + } else { + ProgressPercent::default() + }; + let update_flags = self.internal.get_update_flags(); + (progress, update_flags) + } + _ => ( + ProgressPercent::default(), + self.internal.get_update_flags(), + ), + }; + + let (aux_state, aux_state_status) = match self.internal.get_fd_req_state() { + FdReqState::Unused => ( + AuxState::IdleLearnComponentsReadXfer, + AuxStateStatus::AuxStateInProgressOrSuccess as u8, + ), + FdReqState::Sent => ( + AuxState::OperationInProgress, + AuxStateStatus::AuxStateInProgressOrSuccess as u8, + ), + FdReqState::Ready => { + if self.internal.is_fd_req_complete() { + ( + AuxState::OperationSuccessful, + AuxStateStatus::AuxStateInProgressOrSuccess as u8, + ) + } else { + ( + AuxState::OperationInProgress, + AuxStateStatus::AuxStateInProgressOrSuccess as u8, + ) + } + } + FdReqState::Failed => { + let status = self + .internal + .get_fd_req_result() + .unwrap_or(AuxStateStatus::GenericError as u8); + (AuxState::OperationFailed, status) + } + }; + + let resp = GetStatusResponse::new( + req.hdr.instance_id(), + PldmBaseCompletionCode::Success as u8, + cur_state, + prev_state, + aux_state, + aux_state_status, + progress_percent, + self.internal + .get_fd_reason() + .unwrap_or(GetStatusReasonCode::Initialization), + if update_flags.request_force_update() { + UpdateOptionResp::ForceUpdate + } else { + UpdateOptionResp::NoForceUpdate + }, + ); + + match resp.encode(payload) { + Ok(bytes) => Ok(bytes), + Err(_) => { + generate_failure_response(payload, PldmBaseCompletionCode::InvalidLength as u8) + } + } + } + + pub fn set_fd_t1_ts(&mut self) { + self.internal + .set_fd_t1_update_ts(self.ops.now()); + } + + pub fn should_start_initiator_mode(&mut self) -> bool { + self.internal.get_fd_state() == FirmwareDeviceState::Download + } + + pub fn should_stop_initiator_mode(&mut self) -> bool { + !matches!( + self.internal.get_fd_state(), + FirmwareDeviceState::Download + | FirmwareDeviceState::Verify + | FirmwareDeviceState::Apply + ) + } + + pub fn fd_progress(&mut self, payload: &mut [u8]) -> Result { + let fd_state = self.internal.get_fd_state(); + + let result = match fd_state { + FirmwareDeviceState::Download => self.fd_progress_download(payload), + FirmwareDeviceState::Verify => self.pldm_fd_progress_verify(payload), + FirmwareDeviceState::Apply => self.pldm_fd_progress_apply(payload), + _ => Err(MsgHandlerError::FdInitiatorModeError), + }?; + + // If a response is not received within T1 in FD-driven states, cancel the update and transition to idle state. + if (fd_state == FirmwareDeviceState::Download + || fd_state == FirmwareDeviceState::Verify + || fd_state == FirmwareDeviceState::Apply) + && self.internal.get_fd_req_state() == FdReqState::Sent + && self.ops.now() - self.internal.get_fd_t1_update_ts() + > self.internal.get_fd_t1_timeout() + { + self.ops + .cancel_update_component(&self.internal.get_component()) + .map_err(MsgHandlerError::FdOps)?; + self.internal.fd_idle_timeout(); + return Ok(0); + } + + Ok(result) + } + + pub fn handle_response(&mut self, payload: &mut [u8]) -> Result<(), MsgHandlerError> { + let rsp_header = + PldmMsgHeader::<[u8; 3]>::decode(payload).map_err(MsgHandlerError::Codec)?; + let (cmd_code, instance_id) = (rsp_header.cmd_code(), rsp_header.instance_id()); + + let fd_req = self.internal.get_fd_req(); + if fd_req.state != FdReqState::Sent + || fd_req.instance_id != Some(instance_id) + || fd_req.command != Some(cmd_code) + { + // Unexpected response + return Err(MsgHandlerError::FdInitiatorModeError); + } + + self.set_fd_t1_ts(); + + match FwUpdateCmd::try_from(cmd_code) { + Ok(FwUpdateCmd::RequestFirmwareData) => self.process_request_fw_data_rsp(payload), + Ok(FwUpdateCmd::TransferComplete) => self.process_transfer_complete_rsp(payload), + Ok(FwUpdateCmd::VerifyComplete) => self.process_verify_complete_rsp(payload), + Ok(FwUpdateCmd::ApplyComplete) => self.process_apply_complete_rsp(payload), + _ => Err(MsgHandlerError::FdInitiatorModeError), + } + } + + fn process_request_fw_data_rsp(&mut self, payload: &mut [u8]) -> Result<(), MsgHandlerError> { + let fd_state = self.internal.get_fd_state(); + if fd_state != FirmwareDeviceState::Download { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + let fd_req = self.internal.get_fd_req(); + if fd_req.complete { + // Received data after completion + return Err(MsgHandlerError::FdInitiatorModeError); + } + + // Decode the response message fixed + let fw_data_rsp_fixed: RequestFirmwareDataResponseFixed = + RequestFirmwareDataResponseFixed::decode(payload).map_err(MsgHandlerError::Codec)?; + + match fw_data_rsp_fixed.completion_code { + code if code == PldmBaseCompletionCode::Success as u8 => {} + code if code == FwUpdateCompletionCode::RetryRequestFwData as u8 => return Ok(()), + _ => { + self.internal + .set_fd_req( + FdReqState::Ready, + true, + Some(TransferResult::FdAbortedTransfer as u8), + None, + None, + None, + ); + return Ok(()); + } + } + + let (offset, length) = self.internal.get_fd_download_state().unwrap(); + + let fw_data = payload[core::mem::size_of::()..] + .get(..length as usize) + .ok_or(MsgHandlerError::Codec(PldmCodecError::BufferTooShort))?; + + let fw_component = &self.internal.get_component(); + let res = self + .ops + .download_fw_data(offset as usize, fw_data, fw_component) + .map_err(MsgHandlerError::FdOps)?; + + if res == TransferResult::TransferSuccess { + if self.ops.is_download_complete(fw_component) { + // Mark as complete, next progress() call will send the TransferComplete request + self.internal + .set_fd_req( + FdReqState::Ready, + true, + Some(TransferResult::TransferSuccess as u8), + None, + None, + None, + ); + } else { + // Invoke another request if there is more data to download + self.internal + .set_fd_req(FdReqState::Ready, false, None, None, None, None); + } + } else { + // Pass the callback error as the TransferResult + self.internal + .set_fd_req(FdReqState::Ready, true, Some(res as u8), None, None, None); + } + Ok(()) + } + + fn process_transfer_complete_rsp( + &mut self, + _payload: &mut [u8], + ) -> Result<(), MsgHandlerError> { + let fd_state = self.internal.get_fd_state(); + if fd_state != FirmwareDeviceState::Download { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + let fd_req = self.internal.get_fd_req(); + if fd_req.state != FdReqState::Sent || !fd_req.complete { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + /* Next state depends whether the transfer succeeded */ + if fd_req.result == Some(TransferResult::TransferSuccess as u8) { + // Switch to Verify + self.internal + .set_initiator_mode(InitiatorModeState::Verify(VerifyState::default())); + self.internal + .set_fd_req(FdReqState::Ready, false, None, None, None, None); + self.internal + .set_fd_state(FirmwareDeviceState::Verify); + } else { + // Wait for UA to cancel + self.internal + .set_fd_req(FdReqState::Failed, true, fd_req.result, None, None, None); + } + + Ok(()) + } + + fn process_verify_complete_rsp( + &mut self, + _payload: &mut [u8], + ) -> Result<(), MsgHandlerError> { + let fd_state = self.internal.get_fd_state(); + if fd_state != FirmwareDeviceState::Verify { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + let fd_req = self.internal.get_fd_req(); + if fd_req.state != FdReqState::Sent || !fd_req.complete { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + /* Next state depends whether the verify succeeded */ + if fd_req.result == Some(VerifyResult::VerifySuccess as u8) { + // Switch to Apply + self.internal + .set_initiator_mode(InitiatorModeState::Apply(ApplyState::default())); + self.internal + .set_fd_req(FdReqState::Ready, false, None, None, None, None); + self.internal.set_fd_state(FirmwareDeviceState::Apply); + } else { + // Wait for UA to cancel + self.internal + .set_fd_req(FdReqState::Failed, true, fd_req.result, None, None, None); + } + + Ok(()) + } + + fn process_apply_complete_rsp(&mut self, _payload: &mut [u8]) -> Result<(), MsgHandlerError> { + let fd_state = self.internal.get_fd_state(); + if fd_state != FirmwareDeviceState::Apply { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + let fd_req = self.internal.get_fd_req(); + if fd_req.state != FdReqState::Sent || !fd_req.complete { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + if fd_req.result == Some(ApplyResult::ApplySuccess as u8) { + // Switch to Xfer + self.internal + .set_fd_req(FdReqState::Unused, false, None, None, None, None); + self.internal + .set_fd_state(FirmwareDeviceState::ReadyXfer); + } else { + // Wait for UA to cancel + self.internal + .set_fd_req(FdReqState::Failed, true, fd_req.result, None, None, None); + } + + Ok(()) + } + + fn fd_progress_download(&mut self, payload: &mut [u8]) -> Result { + if !self.should_send_fd_request() { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + let instance_id = self.internal.alloc_next_instance_id().unwrap(); + // If the request is complete, send TransferComplete + if self.internal.is_fd_req_complete() { + let result = self + .internal + .get_fd_req_result() + .ok_or(MsgHandlerError::FdInitiatorModeError)?; + + let msg_len = TransferCompleteRequest::new( + instance_id, + PldmMsgType::Request, + TransferResult::try_from(result).unwrap(), + ) + .encode(payload) + .map_err(MsgHandlerError::Codec)?; + + // Set fd req state to sent + let req_sent_timestamp = self.ops.now(); + self.internal + .set_fd_req( + FdReqState::Sent, + true, + Some(result), + Some(instance_id), + Some(FwUpdateCmd::TransferComplete as u8), + Some(req_sent_timestamp), + ); + + Ok(msg_len) + } else { + let (requested_offset, requested_length) = self + .ops + .query_download_offset_and_length(&self.internal.get_component()) + .map_err(MsgHandlerError::FdOps)?; + + if let Some((chunk_offset, chunk_length)) = self + .internal + .get_fd_download_chunk(requested_offset as u32, requested_length as u32) + { + let msg_len = RequestFirmwareDataRequest::new( + instance_id, + PldmMsgType::Request, + chunk_offset, + chunk_length, + ) + .encode(payload) + .map_err(MsgHandlerError::Codec)?; + + // Store offset and length into the internal state + self.internal + .set_fd_download_state(chunk_offset, chunk_length); + + // Set fd req state to sent + let req_sent_timestamp = self.ops.now(); + self.internal + .set_fd_req( + FdReqState::Sent, + false, + None, + Some(instance_id), + Some(FwUpdateCmd::RequestFirmwareData as u8), + Some(req_sent_timestamp), + ); + Ok(msg_len) + } else { + Err(MsgHandlerError::FdInitiatorModeError) + } + } + } + + fn pldm_fd_progress_verify(&mut self, _payload: &mut [u8]) -> Result { + if !self.should_send_fd_request() { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + let mut res = VerifyResult::default(); + if !self.internal.is_fd_req_complete() { + let mut progress_percent = ProgressPercent::default(); + res = self + .ops + .verify(&self.internal.get_component(), &mut progress_percent) + .map_err(MsgHandlerError::FdOps)?; + + // Set the progress percent to VerifyState + self.internal + .set_fd_verify_progress(progress_percent.value()); + + if res == VerifyResult::VerifySuccess && progress_percent.value() < 100 { + // doing nothing and wait for the next call + return Ok(0); + } + } + + let instance_id = self.internal.alloc_next_instance_id().unwrap(); + let verify_complete_req = + VerifyCompleteRequest::new(instance_id, PldmMsgType::Request, res); + + // Encode the request message + let msg_len = verify_complete_req + .encode(_payload) + .map_err(MsgHandlerError::Codec)?; + + self.internal + .set_fd_req( + FdReqState::Sent, + true, + Some(res as u8), + Some(instance_id), + Some(FwUpdateCmd::VerifyComplete as u8), + Some(self.ops.now()), + ); + + Ok(msg_len) + } + + fn pldm_fd_progress_apply(&mut self, _payload: &mut [u8]) -> Result { + if !self.should_send_fd_request() { + return Err(MsgHandlerError::FdInitiatorModeError); + } + + let mut res = ApplyResult::default(); + if !self.internal.is_fd_req_complete() { + let mut progress_percent = ProgressPercent::default(); + res = self + .ops + .apply(&self.internal.get_component(), &mut progress_percent) + .map_err(MsgHandlerError::FdOps)?; + + // Set the progress percent to ApplyState + self.internal + .set_fd_apply_progress(progress_percent.value()); + + if res == ApplyResult::ApplySuccess && progress_percent.value() < 100 { + // doing nothing and wait for the next call + return Ok(0); + } + } + + // Allocate the next instance ID + let instance_id = self.internal.alloc_next_instance_id().unwrap(); + let apply_complete_req = ApplyCompleteRequest::new( + instance_id, + PldmMsgType::Request, + res, + ComponentActivationMethods(0), + ); + // Encode the request message + let msg_len = apply_complete_req + .encode(_payload) + .map_err(MsgHandlerError::Codec)?; + + self.internal + .set_fd_req( + FdReqState::Sent, + true, + Some(res as u8), + Some(instance_id), + Some(FwUpdateCmd::ApplyComplete as u8), + Some(self.ops.now()), + ); + + Ok(msg_len) + } + + fn should_send_fd_request(&self) -> bool { + let now = self.ops.now(); + + let fd_req_state = self.internal.get_fd_req_state(); + match fd_req_state { + FdReqState::Unused => false, + FdReqState::Ready => true, + FdReqState::Failed => false, + FdReqState::Sent => { + let fd_req_sent_time = self.internal.get_fd_sent_time().unwrap(); + if now < fd_req_sent_time { + // Time went backwards + return false; + } + + // Send if retry time has elapsed + return (now - fd_req_sent_time) >= self.internal.get_fd_t2_retry_time(); + } + } + } +} diff --git a/src/firmware_device/fd_internal.rs b/src/firmware_device/fd_internal.rs new file mode 100644 index 0000000..bf60f28 --- /dev/null +++ b/src/firmware_device/fd_internal.rs @@ -0,0 +1,364 @@ +// Licensed under the Apache-2.0 license + +use crate::control_context::Tid; +use pldm_common::message::firmware_update::get_status::GetStatusReasonCode; +use pldm_common::protocol::firmware_update::{ + FirmwareDeviceState, PldmFdTime, UpdateOptionFlags, PLDM_FWUP_MAX_PADDING_SIZE, +}; +use pldm_common::util::fw_component::FirmwareComponent; + +pub struct FdInternal { + // Current state of the firmware device. + state: FirmwareDeviceState, + + // Previous state of the firmware device. + prev_state: FirmwareDeviceState, + + // Reason for the last transition to the idle state. + // Only valid when `state == FirmwareDeviceState::Idle`. + reason: Option, + + // Details of the component currently being updated. + // Set by `UpdateComponent`, available during download/verify/apply. + update_comp: FirmwareComponent, + + // Flags indicating update options. + update_flags: UpdateOptionFlags, + + // Maximum transfer size allowed by the UA or platform implementation. + max_xfer_size: u32, + + // Request details used for download/verify/apply operations. + req: FdReq, + + // Mode-specific data for the requester. + initiator_mode_state: InitiatorModeState, + + // Address of the Update Agent (UA). + _ua_address: Option, + + // Timestamp for FD T1 timeout in milliseconds. + fd_t1_update_ts: PldmFdTime, + + fd_t1_timeout: PldmFdTime, + fd_t2_retry_time: PldmFdTime, +} + +impl Default for FdInternal { + fn default() -> Self { + Self::new( + crate::config::FD_MAX_XFER_SIZE as u32, + crate::config::DEFAULT_FD_T1_TIMEOUT, + crate::config::DEFAULT_FD_T2_RETRY_TIME, + ) + } +} + +impl FdInternal { + pub fn new(max_xfer_size: u32, fd_t1_timeout: u64, fd_t2_retry_time: u64) -> Self { + Self { + state: FirmwareDeviceState::Idle, + prev_state: FirmwareDeviceState::Idle, + reason: None, + update_comp: FirmwareComponent::default(), + update_flags: UpdateOptionFlags(0), + max_xfer_size, + req: FdReq::new(), + initiator_mode_state: InitiatorModeState::Download(DownloadState::default()), + _ua_address: None, + fd_t1_update_ts: 0, + fd_t1_timeout, + fd_t2_retry_time, + } + } + + pub fn is_update_mode(&self) -> bool { + self.state != FirmwareDeviceState::Idle + } + + pub fn set_fd_state(&mut self, state: FirmwareDeviceState) { + if self.state != state { + self.prev_state = self.state.clone(); + self.state = state; + } + } + + pub fn set_fd_idle(&mut self, reason_code: GetStatusReasonCode) { + if self.state != FirmwareDeviceState::Idle { + self.prev_state = self.state.clone(); + self.state = FirmwareDeviceState::Idle; + self.reason = Some(reason_code); + } + } + + pub fn fd_idle_timeout(&mut self) { + let state = self.get_fd_state(); + let reason = match state { + FirmwareDeviceState::Idle => return, + FirmwareDeviceState::LearnComponents => GetStatusReasonCode::LearnComponentTimeout, + FirmwareDeviceState::ReadyXfer => GetStatusReasonCode::ReadyXferTimeout, + FirmwareDeviceState::Download => GetStatusReasonCode::DownloadTimeout, + FirmwareDeviceState::Verify => GetStatusReasonCode::VerifyTimeout, + FirmwareDeviceState::Apply => GetStatusReasonCode::ApplyTimeout, + FirmwareDeviceState::Activate => GetStatusReasonCode::ActivateFw, + }; + + self.set_fd_idle(reason); + } + + pub fn get_fd_reason(&self) -> Option { + self.reason + } + + pub fn get_fd_state(&self) -> FirmwareDeviceState { + self.state.clone() + } + + pub fn get_fd_prev_state(&self) -> FirmwareDeviceState { + self.prev_state.clone() + } + + pub fn set_xfer_size(&mut self, transfer_size: usize) { + self.max_xfer_size = transfer_size as u32; + } + + pub fn get_xfer_size(&self) -> usize { + self.max_xfer_size as usize + } + + pub fn set_component(&mut self, comp: &FirmwareComponent) { + self.update_comp = comp.clone(); + } + + pub fn get_component(&self) -> FirmwareComponent { + self.update_comp.clone() + } + + pub fn set_update_flags(&mut self, flags: UpdateOptionFlags) { + self.update_flags = flags; + } + + pub fn get_update_flags(&self) -> UpdateOptionFlags { + self.update_flags + } + + pub fn set_fd_req( + &mut self, + req_state: FdReqState, + complete: bool, + result: Option, + instance_id: Option, + command: Option, + sent_time: Option, + ) { + self.req = FdReq { + state: req_state, + complete, + result, + instance_id, + command, + sent_time, + }; + } + + pub fn alloc_next_instance_id(&mut self) -> Option { + self.req.instance_id = Some( + self + .req + .instance_id + .map_or(1, |id| (id + 1) % crate::config::INSTANCE_ID_COUNT), + ); + self.req.instance_id + } + + pub fn get_fd_req(&self) -> FdReq { + self.req.clone() + } + + pub fn get_fd_req_state(&self) -> FdReqState { + self.req.state.clone() + } + + pub fn set_fd_req_state(&mut self, state: FdReqState) { + self.req.state = state; + } + + pub fn get_fd_sent_time(&self) -> Option { + self.req.sent_time + } + + pub fn is_fd_req_complete(&self) -> bool { + self.req.complete + } + + pub fn get_fd_req_result(&self) -> Option { + self.req.result + } + + pub fn get_fd_download_chunk( + &self, + requested_offset: u32, + requested_length: u32, + ) -> Option<(u32, u32)> { + if self.state != FirmwareDeviceState::Download { + return None; + } + + let comp_image_size = self.update_comp.comp_image_size.unwrap_or(0); + if requested_offset > comp_image_size + || requested_offset + requested_length + > comp_image_size + PLDM_FWUP_MAX_PADDING_SIZE as u32 + { + return None; + } + let chunk_size = requested_length.min(self.max_xfer_size); + Some((requested_offset, chunk_size)) + } + + pub fn get_fd_download_state(&self) -> Option<(u32, u32)> { + if let InitiatorModeState::Download(download) = &self.initiator_mode_state { + Some((download.offset, download.length)) + } else { + None + } + } + + pub fn set_fd_download_state(&mut self, offset: u32, length: u32) { + if let InitiatorModeState::Download(download) = &mut self.initiator_mode_state { + download.offset = offset; + download.length = length; + } + } + + pub fn set_initiator_mode(&mut self, mode: InitiatorModeState) { + self.initiator_mode_state = mode; + } + + pub fn set_fd_verify_progress(&mut self, progress: u8) { + if let InitiatorModeState::Verify(verify) = &mut self.initiator_mode_state { + verify.progress_percent = progress; + } + } + + pub fn set_fd_apply_progress(&mut self, progress: u8) { + if let InitiatorModeState::Apply(apply) = &mut self.initiator_mode_state { + apply.progress_percent = progress; + } + } + + pub fn get_fd_verify_progress(&mut self) -> Option { + if let InitiatorModeState::Verify(verify) = &mut self.initiator_mode_state { + Some(verify.progress_percent) + } else { + None + } + } + + pub fn get_fd_apply_progress(&self) -> Option { + if let InitiatorModeState::Apply(apply) = &self.initiator_mode_state { + Some(apply.progress_percent) + } else { + None + } + } + + pub fn set_fd_t1_update_ts(&mut self, timestamp: PldmFdTime) { + self.fd_t1_update_ts = timestamp; + } + + pub fn get_fd_t1_update_ts(&self) -> PldmFdTime { + self.fd_t1_update_ts + } + + pub fn set_fd_t1_timeout(&mut self, timeout: PldmFdTime) { + self.fd_t1_timeout = timeout; + } + + pub fn get_fd_t1_timeout(&self) -> PldmFdTime { + self.fd_t1_timeout + } + + pub fn set_fd_t2_retry_time(&mut self, retry_time: PldmFdTime) { + self.fd_t2_retry_time = retry_time; + } + + pub fn get_fd_t2_retry_time(&self) -> PldmFdTime { + self.fd_t2_retry_time + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FdReqState { + // The `pldm_fd_req` instance is unused. + Unused, + // Ready to send a request. + Ready, + // Waiting for a response. + Sent, + // Completed and failed; will not send more requests. + Failed, +} + +#[derive(Debug, Clone)] +pub struct FdReq { + // The current state of the request. + pub state: FdReqState, + + // Indicates if the request is complete and ready to transition to the next state. + // This is relevant for TransferComplete, VerifyComplete, and ApplyComplete requests. + pub complete: bool, + + // The result of the request, only valid when `complete` is set. + pub result: Option, + + // The instance ID of the request, only valid in the `SENT` state. + pub instance_id: Option, + + // The command associated with the request, only valid in the `SENT` state. + pub command: Option, + + // The time when the request was sent, only valid in the `SENT` state. + pub sent_time: Option, +} + +impl Default for FdReq { + fn default() -> Self { + Self::new() + } +} + +impl FdReq { + fn new() -> Self { + Self { + state: FdReqState::Unused, + complete: false, + result: None, + instance_id: None, + command: None, + sent_time: None, + } + } +} + +#[derive(Debug)] +pub enum InitiatorModeState { + Download(DownloadState), + Verify(VerifyState), + Apply(ApplyState), +} + +#[derive(Debug, Default)] +pub struct DownloadState { + pub offset: u32, + pub length: u32, +} + +#[derive(Debug, Default)] +pub struct VerifyState { + pub progress_percent: u8, +} + +#[derive(Debug, Default)] +pub struct ApplyState { + pub progress_percent: u8, +} diff --git a/src/firmware_device/fd_ops.rs b/src/firmware_device/fd_ops.rs new file mode 100644 index 0000000..f9f746d --- /dev/null +++ b/src/firmware_device/fd_ops.rs @@ -0,0 +1,270 @@ +// Licensed under the Apache-2.0 license + +use pldm_common::message::firmware_update::apply_complete::ApplyResult; +use pldm_common::message::firmware_update::get_status::ProgressPercent; +use pldm_common::message::firmware_update::request_cancel::{ + NonFunctioningComponentBitmap, NonFunctioningComponentIndication, +}; +use pldm_common::message::firmware_update::transfer_complete::TransferResult; +use pldm_common::message::firmware_update::verify_complete::VerifyResult; +use pldm_common::util::fw_component::FirmwareComponent; +use pldm_common::{ + message::firmware_update::get_fw_params::FirmwareParameters, + protocol::firmware_update::{ComponentResponseCode, Descriptor, PldmFdTime}, +}; + +#[derive(Debug)] +pub enum FdOpsError { + DeviceIdentifiersError, + FirmwareParametersError, + TransferSizeError, + ComponentError, + FwDownloadError, + VerifyError, + ApplyError, + ActivateError, + CancelUpdateError, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ComponentOperation { + PassComponent, + UpdateComponent, +} + +pub struct FdOps { } + +/// Trait for firmware device-specific operations. +/// +/// This trait defines asynchronous methods for performing various firmware device operations, +/// including retrieving device identifiers, firmware parameters, and transfer sizes. It also +/// provides methods for handling firmware components, managing firmware data downloads, verifying +/// and applying firmware, activating new firmware, and obtaining the current timestamp. +impl FdOps { + + pub fn new() -> Self { + Self { } + } + + /// Asynchronously retrieves device identifiers. + /// + /// # Arguments + /// + /// * `device_identifiers` - A mutable slice of `Descriptor` to store the retrieved device identifiers. + /// + /// # Returns + /// + /// * `Result` - On success, returns the number of device identifiers retrieved. + /// On failure, returns an `FdOpsError`. + pub fn get_device_identifiers( + &self, + device_identifiers: &mut [Descriptor], + ) -> Result { Ok(0xff) } + + /// Asynchronously retrieves firmware parameters. + /// + /// # Arguments + /// + /// * `firmware_params` - A mutable reference to `FirmwareParameters` to store the retrieved firmware parameters. + /// + /// # Returns + /// + /// * `Result<(), FdOpsError>` - On success, returns `Ok(())`. On failure, returns an `FdOpsError`. + pub fn get_firmware_parms( + &self, + firmware_params: &mut FirmwareParameters, + ) -> Result<(), FdOpsError> { Result::Ok(()) } + + /// Retrieves the transfer size for the firmware update operation. + /// + /// # Arguments + /// + /// * `ua_transfer_size` - The requested transfer size in bytes. + /// + /// # Returns + /// + /// * `Result` - On success, returns the transfer size in bytes. + /// On failure, returns an `FdOpsError`. + pub fn get_xfer_size(&self, ua_transfer_size: usize) -> Result { Ok(0xff) } + + /// Handles firmware component operations such as passing or updating components. + /// + /// # Arguments + /// + /// * `component` - A reference to the `FirmwareComponent` to be processed. + /// * `fw_params` - A reference to the `FirmwareParameters` associated with the operation. + /// * `op` - The `ComponentOperation` to be performed (e.g., pass or update). + /// + /// # Returns + /// + /// * `Result` - On success, returns a `ComponentResponseCode`. + /// On failure, returns an `FdOpsError`. + pub fn handle_component( + &self, + component: &FirmwareComponent, + fw_params: &FirmwareParameters, + op: ComponentOperation, + ) -> Result { Ok(ComponentResponseCode::CompNotSupported) } + + /// Queries the download offset and length for a given firmware component. + /// + /// # Arguments + /// + /// * `component` - A reference to the `FirmwareComponent` for which the download offset and length are queried. + /// + /// # Returns + /// + /// * `Result<(usize, usize), FdOpsError>` - On success, returns a tuple containing the offset and length in bytes. + /// On failure, returns an `FdOpsError`. + pub fn query_download_offset_and_length( + &self, + component: &FirmwareComponent, + ) -> Result<(usize, usize), FdOpsError> { Ok((0xff, 0xff)) } + + /// Handles firmware data downloading operations. + /// + /// # Arguments + /// + /// * `offset` - The offset in bytes where the firmware data should be written or processed. + /// * `data` - A slice of bytes representing the firmware data to be handled. + /// * `component` - A reference to the `FirmwareComponent` associated with the firmware data. + /// + /// # Returns + /// + /// * `Result` - On success, returns a `TransferResult` indicating the outcome of the operation. + /// On failure, returns an `FdOpsError`. + pub fn download_fw_data( + &self, + offset: usize, + data: &[u8], + component: &FirmwareComponent, + ) -> Result { Ok(TransferResult::TransferTimeOut) } + + /// Checks if the firmware download for a given component is complete. + /// + /// # Arguments + /// + /// * `component` - A reference to the `FirmwareComponent` for which the download completion status is checked. + /// + /// # Returns + /// + /// * `bool` - Returns `true` if the download is complete, otherwise `false`. + pub fn is_download_complete(&self, component: &FirmwareComponent) -> bool { false } + + /// Queries the download progress for a given firmware component. + /// + /// # Arguments + /// + /// * `component` - A reference to the `FirmwareComponent` for which the download progress is queried. + /// * `progress_percent` - A mutable reference to `ProgressPercent` to track the download progress. + /// + /// # Returns + /// + /// * `Result<(), FdOpsError>` - On success, returns `Ok(())`. On failure, returns an `FdOpsError`. + pub fn query_download_progress( + &self, + component: &FirmwareComponent, + progress_percent: &mut ProgressPercent, + ) -> Result<(), FdOpsError> { Ok(()) } + + /// Verifies the firmware component. + /// + /// # Arguments + /// + /// * `component` - A reference to the `FirmwareComponent` to be verified. + /// * `progress_percent` - A mutable reference to `ProgressPercent` to track the verification progress. + /// + /// # Returns + /// + /// * `Result` - On success, returns a `VerifyResult` indicating the outcome of the verification. + /// * On failure, returns an `FdOpsError`. + pub fn verify( + &self, + component: &FirmwareComponent, + progress_percent: &mut ProgressPercent, + ) -> Result { Ok(VerifyResult::VerifyGenericError) } + + /// Applies the firmware component. + /// + /// # Arguments + /// + /// * `component` - A reference to the `FirmwareComponent` to be applied. + /// * `progress_percent` - A mutable reference to `ProgressPercent` to track the application progress. + /// + /// # Returns + /// + /// * `Result` - On success, returns an `ApplyResult` indicating the outcome of the application. + /// * On failure, returns an `FdOpsError`. + pub fn apply( + &self, + component: &FirmwareComponent, + progress_percent: &mut ProgressPercent, + ) -> Result { Ok(ApplyResult::ApplyGenericError) } + + /// Activates new firmware. + /// + /// # Arguments + /// + /// * `self_contained_activation` - Indicates if self-contained activation is requested. + /// * `estimated_time` - A mutable reference to store the estimated time (in seconds) + /// required to perform self-activation. This may be left as `None` if not needed. + /// + /// # Returns + /// + /// * `Result` - On success, returns a PLDM completion code. + /// On failure, returns an `FdOpsError`. + /// + /// The device implementation is responsible for verifying that the expected components + /// have been updated. If not, it should return `PLDM_FWUP_INCOMPLETE_UPDATE`. + pub fn activate( + &self, + self_contained_activation: u8, + estimated_time: &mut u16, + ) -> Result { Ok(0xff) } + + /// Cancels the update operation for a specific firmware component. + /// + /// # Arguments + /// + /// * `component` - A reference to the `FirmwareComponent` for which the update operation should be canceled. + /// + /// # Returns + /// + /// * `Result<(), FdOpsError>` - On success, returns `Ok(())`. On failure, returns an `FdOpsError`. + pub fn cancel_update_component( + &self, + component: &FirmwareComponent, + ) -> Result<(), FdOpsError> { Ok(())} + + /// Indicates which components will be in a non-functioning state upon exiting update mode + /// due to cancel update request from UA. + /// + /// # Returns + /// + /// * `Result<(NonFunctioningComponentIndication, NonFunctioningComponentBitmap), FdOpsError>` - + /// On success, returns a tuple containing: + /// - `NonFunctioningComponentIndication`: Indicates whether components are functioning or not. + /// - `NonFunctioningComponentBitmap`: A bitmap representing non-functioning components. + /// On failure, returns an `FdOpsError`. + pub fn get_non_functional_component_info( + &self, + ) -> Result< + ( + NonFunctioningComponentIndication, + NonFunctioningComponentBitmap, + ), + FdOpsError, + > { + Ok(( + NonFunctioningComponentIndication::ComponentsFunctioning, + NonFunctioningComponentBitmap::new(0), + )) + } + + /// Retrieves the current timestamp in milliseconds. + /// + /// # Returns + /// + /// * `PldmFdTime` - The current timestamp in milliseconds. + pub fn now(&self) -> PldmFdTime { 0xbaddbadd } +} diff --git a/src/firmware_device/mod.rs b/src/firmware_device/mod.rs new file mode 100644 index 0000000..f1eb8b4 --- /dev/null +++ b/src/firmware_device/mod.rs @@ -0,0 +1,5 @@ +// Licensed under the Apache-2.0 license + +pub mod fd_context; +pub mod fd_internal; +pub mod fd_ops; diff --git a/src/lib.rs b/src/lib.rs index 660c84c..1ff2dc7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,17 @@ // Licensed under the Apache-2.0 license +#![no_std] -#![cfg_attr(target_arch = "riscv32", no_std)] +// Re-export core types for no_std compatibility +pub use core::{ + option::Option::{self, Some, None}, + result::Result::{self, Ok, Err}, +}; -pub mod codec; + +pub mod cmd_interface; +pub mod config; +pub mod control_context; pub mod error; -pub mod message; -pub mod protocol; -pub mod util; +pub mod firmware_device; +//pub mod timer; +pub mod transport; diff --git a/src/timer.rs b/src/timer.rs new file mode 100644 index 0000000..b6f1ad4 --- /dev/null +++ b/src/timer.rs @@ -0,0 +1,83 @@ +// Licensed under the Apache-2.0 license + +use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex; +use embassy_sync::mutex::Mutex; +use libsyscall_caliptra::DefaultSyscalls; +use libtock_alarm::{Convert, Hz, Milliseconds}; +use libtock_platform::{self as platform}; +use libtock_platform::{DefaultConfig, ErrorCode, Syscalls}; +use libtockasync::TockSubscribe; + +pub struct AsyncAlarm( + S, + C, +); + +static ALARM_MUTEX: Mutex = Mutex::new(()); + +impl AsyncAlarm { + /// Run a check against the console capsule to ensure it is present. + #[inline(always)] + #[allow(dead_code)] + pub fn exists() -> Result<(), ErrorCode> { + S::command(DRIVER_NUM, command::EXISTS, 0, 0).to_result() + } + + pub fn get_frequency() -> Result { + S::command(DRIVER_NUM, command::FREQUENCY, 0, 0) + .to_result() + .map(Hz) + } + + #[allow(dead_code)] + pub fn get_ticks() -> Result { + S::command(DRIVER_NUM, command::TIME, 0, 0).to_result() + } + + pub fn get_milliseconds() -> Result { + let ticks = Self::get_ticks()? as u64; + let freq = (Self::get_frequency()?).0 as u64; + + Ok(ticks.saturating_div(freq / 1000)) + } + + pub async fn sleep_for(time: T) -> Result<(), ErrorCode> { + let freq = Self::get_frequency()?; + let ticks = time.to_ticks(freq).0; + let sub = TockSubscribe::subscribe::(DRIVER_NUM, 0); + S::command(DRIVER_NUM, command::SET_RELATIVE, ticks, 0) + .to_result() + .map(|_when: u32| ())?; + sub.await.map(|_| ()) + } + + pub async fn sleep(time: Milliseconds) { + // bad things happen if multiple tasks try to use the alarm at once + let guard = ALARM_MUTEX.lock().await; + let _ = AsyncAlarm::::sleep_for(time).await; + drop(guard); + } +} + +// ----------------------------------------------------------------------------- +// Driver number and command IDs +// ----------------------------------------------------------------------------- + +const DRIVER_NUM: u32 = 0; + +// Command IDs +#[allow(unused)] +mod command { + pub const EXISTS: u32 = 0; + pub const FREQUENCY: u32 = 1; + pub const TIME: u32 = 2; + pub const STOP: u32 = 3; + + pub const SET_RELATIVE: u32 = 5; + pub const SET_ABSOLUTE: u32 = 6; +} + +#[allow(unused)] +mod subscribe { + pub const CALLBACK: u32 = 0; +} diff --git a/src/transport.rs b/src/transport.rs new file mode 100644 index 0000000..1fdb4b0 --- /dev/null +++ b/src/transport.rs @@ -0,0 +1,48 @@ +// Licensed under the Apache-2.0 license +use pldm_common::util::mctp_transport::{ + MctpCommonHeader, MCTP_COMMON_HEADER_OFFSET, MCTP_PLDM_MSG_TYPE, +}; + +pub enum PldmTransportType { + Mctp, +} + +#[derive(Debug)] +pub enum TransportError { + DriverError, + BufferTooSmall, + UnexpectedMessageType, + ReceiveError, + SendError, + ResponseNotExpected, + NoRequestInFlight, +} + +pub struct MctpTransport; + +impl MctpTransport { + pub fn new() -> Self { + MctpTransport + } + + pub fn send_request(&mut self, _dest_eid: u8, _req: &[u8]) -> Result<(), TransportError> { + // TODO: Implement actual MCTP transport + Ok(()) + } + + pub fn receive_response(&mut self, _rsp: &mut [u8]) -> Result<(), TransportError> { + // TODO: Implement actual MCTP transport + Ok(()) + } + + pub fn receive_request(&mut self, _req: &mut [u8]) -> Result<(), TransportError> { + // TODO: Implement actual MCTP transport + Ok(()) + } + + pub fn send_response(&mut self, _resp: &[u8]) -> Result<(), TransportError> { + // TODO: Implement actual MCTP transport + Ok(()) + } + +} \ No newline at end of file