diff --git a/sonic_platform_base/module_base.py b/sonic_platform_base/module_base.py index bd333571a..438607647 100644 --- a/sonic_platform_base/module_base.py +++ b/sonic_platform_base/module_base.py @@ -13,6 +13,10 @@ import threading import contextlib import shutil +import time +from datetime import datetime, timezone +from swsscommon.swsscommon import SonicV2Connector # type: ignore + # PCI state database constants PCIE_DETACH_INFO_TABLE = "PCIE_DETACH_INFO" @@ -27,6 +31,7 @@ class ModuleBase(device_base.DeviceBase): # Device type definition. Note, this is a constant. DEVICE_TYPE = "module" PCI_OPERATION_LOCK_FILE_PATH = "/var/lock/{}_pci.lock" + TRANSITION_OPERATION_LOCK_FILE_PATH = "/var/lock/{}_transition.lock" SENSORD_OPERATION_LOCK_FILE_PATH = "/var/lock/sensord.lock" # Possible card types for modular chassis @@ -85,7 +90,6 @@ def __init__(self): self._thermal_list = [] self._voltage_sensor_list = [] self._current_sensor_list = [] - self.state_db_connector = None self.pci_bus_info = None # List of SfpBase-derived objects representing all sfps @@ -95,7 +99,21 @@ def __init__(self): # List of ASIC-derived objects representing all ASICs # visibile in PCI domain on the module self._asic_list = [] - + + # Initialize state database connector + self._state_db_connector = self._initialize_state_db_connector() + + def _initialize_state_db_connector(self): + """Initialize a STATE_DB connector using swsscommon only.""" + db = SonicV2Connector(use_string_keys=True) + try: + db.connect(db.STATE_DB) + except RuntimeError as e: + # Some environments autoconnect; preserve tolerant behavior + sys.stderr.write(f"Failed to connect to STATE_DB, continuing: {e}\n") + return None + return db + @contextlib.contextmanager def _file_operation_lock(self, lock_file_path): """Common file-based lock for operations using flock""" @@ -106,6 +124,13 @@ def _file_operation_lock(self, lock_file_path): finally: fcntl.flock(f.fileno(), fcntl.LOCK_UN) + @contextlib.contextmanager + def _transition_operation_lock(self): + """File-based lock for module state transition operations using flock""" + lock_file_path = self.TRANSITION_OPERATION_LOCK_FILE_PATH.format(self.get_name()) + with self._file_operation_lock(lock_file_path): + yield + @contextlib.contextmanager def _pci_operation_lock(self): """File-based lock for PCI operations using flock""" @@ -236,6 +261,57 @@ def set_admin_state(self, up): """ raise NotImplementedError + def set_admin_state_using_graceful_handler(self, up): + """ + Request to set the module's administrative state with graceful shutdown coordination. + + This function is specifically designed for SmartSwitch platforms and should be + called by chassisd to ensure proper graceful shutdown coordination with external + agents (e.g., gNOI clients) before setting admin state to DOWN. + + For non-SmartSwitch platforms or direct platform API usage, use set_admin_state() + instead. + + Args: + up (bool): True for admin UP, False for admin DOWN. + + Returns: + bool: True if the request was successful, False otherwise. + """ + if up: + # Admin UP: Set transition state to 'startup' before admin state change + module_name = self.get_name() + self.set_module_state_transition(self._state_db_connector, module_name, "startup") + admin_state_success = self.set_admin_state(True) + + # Clear transition state after admin state operation completes + if not self.clear_module_state_transition(self._state_db_connector, module_name): + context = "after successful admin state change" if admin_state_success else "after failed admin state change" + sys.stderr.write(f"Failed to clear transition state for module {module_name} {context}.\n") + + return admin_state_success + + # Admin DOWN: Perform graceful shutdown first + module_name = self.get_name() + graceful_success = self.graceful_shutdown_handler() + + if not graceful_success: + # Clear transition state on graceful shutdown failure + if not self.clear_module_state_transition(self._state_db_connector, module_name): + sys.stderr.write(f"Failed to clear transition state for module {module_name} after graceful shutdown failure.\n") + sys.stderr.write(f"Aborting admin-down for module {module_name} due to graceful shutdown failure.\n") + return False + + # Proceed with admin state change + admin_state_success = self.set_admin_state(False) + + # Always clear transition state after admin state operation completes + if not self.clear_module_state_transition(self._state_db_connector, module_name): + context = "after successful admin state change" if admin_state_success else "after failed admin state change" + sys.stderr.write(f"Failed to clear transition state for module {module_name} {context}.\n") + + return admin_state_success + def get_maximum_consumed_power(self): """ Retrives the maximum power drawn by this module @@ -337,21 +413,22 @@ def pci_entry_state_db(self, pcie_string, operation): Args: pcie_string (str): The PCI bus string to be written to state database operation (str): The operation being performed ("detaching" or "attaching") - - Raises: - RuntimeError: If state database connection fails """ try: - # Do not use import if swsscommon is not needed - import swsscommon - PCIE_DETACH_INFO_TABLE_KEY = PCIE_DETACH_INFO_TABLE+"|"+pcie_string - if not self.state_db_connector: - self.state_db_connector = swsscommon.swsscommon.DBConnector("STATE_DB", 0) + db = self._state_db_connector + PCIE_DETACH_INFO_TABLE_KEY = PCIE_DETACH_INFO_TABLE + "|" + pcie_string + if operation == PCIE_OPERATION_ATTACHING: - self.state_db_connector.delete(PCIE_DETACH_INFO_TABLE_KEY) + # Delete the entire entry for attaching operation + if hasattr(db, 'delete'): + db.delete(db.STATE_DB, PCIE_DETACH_INFO_TABLE_KEY, "bus_info") + db.delete(db.STATE_DB, PCIE_DETACH_INFO_TABLE_KEY, "dpu_state") return - self.state_db_connector.hset(PCIE_DETACH_INFO_TABLE_KEY, "bus_info", pcie_string) - self.state_db_connector.hset(PCIE_DETACH_INFO_TABLE_KEY, "dpu_state", operation) + # Set the PCI detach info for detaching operation + db.set(db.STATE_DB, PCIE_DETACH_INFO_TABLE_KEY, { + "bus_info": pcie_string, + "dpu_state": operation + }) except Exception as e: sys.stderr.write("Failed to write pcie bus info to state database: {}\n".format(str(e))) @@ -391,6 +468,287 @@ def pci_reattach(self): """ raise NotImplementedError + # ########################################### + # Smartswitch DPU graceful shutdown helpers + # Transition timeout defaults (seconds) + # These are used unless overridden by /usr/share/sonic/platform/platform.json + # with optional keys: dpu_startup_timeout, dpu_shutdown_timeout, dpu_reboot_timeout + # ########################################### + _TRANSITION_TIMEOUT_DEFAULTS = { + "startup": 300, # 5 minutes + "shutdown": 180, # 3 minutes + "reboot": 240, # 4 minutes + } + # class-level cache to avoid multiple reads per process + _TRANSITION_TIMEOUTS_CACHE = None + + + + def _transition_key(self) -> str: + """Return the STATE_DB key for this module's transition state.""" + # Use get_name() to avoid relying on an attribute that may not exist. + return f"CHASSIS_MODULE_TABLE|{self.get_name()}" + + def _load_transition_timeouts(self) -> dict: + """ + Load per-operation timeouts from /usr/share/sonic/platform/platform.json if present, + otherwise fall back to _TRANSITION_TIMEOUT_DEFAULTS. + + Recognized keys in platform.json: + - dpu_startup_timeout + - dpu_shutdown_timeout + - dpu_reboot_timeout + + Note: + The path used is /usr/share/sonic/platform/platform.json, which may differ from the typical + SONiC platform file location (/usr/share/sonic/device/{plat}/platform.json). This path is + bind-mounted in PMON/containers and is used directly here. + """ + if ModuleBase._TRANSITION_TIMEOUTS_CACHE is not None: + return ModuleBase._TRANSITION_TIMEOUTS_CACHE + + timeouts = dict(self._TRANSITION_TIMEOUT_DEFAULTS) + try: + # The platform.json file is expected at /usr/share/sonic/platform/platform.json. + # This may differ from the typical SONiC device path. + path = "/usr/share/sonic/platform/platform.json" + with open(path, "r") as f: + data = json.load(f) or {} + + if "dpu_startup_timeout" in data: + timeouts["startup"] = int(data["dpu_startup_timeout"]) + if "dpu_shutdown_timeout" in data: + timeouts["shutdown"] = int(data["dpu_shutdown_timeout"]) + if "dpu_reboot_timeout" in data: + timeouts["reboot"] = int(data["dpu_reboot_timeout"]) + except Exception as e: + # On any error, just use defaults + sys.stderr.write(f"Failed to load transition timeouts from platform.json, using defaults: {e}\n") + + ModuleBase._TRANSITION_TIMEOUTS_CACHE = timeouts + return ModuleBase._TRANSITION_TIMEOUTS_CACHE + + def graceful_shutdown_handler(self): + """ + SmartSwitch graceful shutdown gate for DPU modules with race condition protection. + + Coordinates shutdown with external agents (e.g., gNOI clients) by: + 1. Atomically setting CHASSIS_MODULE_TABLE| transition state to "shutdown" + 2. Waiting for external completion signal or module offline status + 3. Cleaning up transition state on completion or timeout + + Race Condition Handling: + - File-based locking ensures only one agent can modify transition state at a time + - Multiple concurrent calls are serialized through set_module_state_transition() + - Timed-out transitions are automatically cleared and new ones can proceed + - Timeout based on database-recorded start time, not individual agent wait time + + Exit Conditions: + - External agent sets state_transition_in_progress="False" (graceful completion) + - Module operational status becomes "Offline" (platform-detected shutdown) + - Timeout after configured period (default: 180s from platform.json dpu_shutdown_timeout) + + Returns: + bool: True if graceful shutdown completes, False on timeout. + + Note: + Called by platform set_admin_state() when transitioning DPU to admin DOWN. + Implements SONiC SmartSwitch graceful shutdown HLD requirements. + """ + db = self._state_db_connector + + module_name = self.get_name() + + # Atomically set transition state (handles race conditions with locking) + # Note: This is safe to call even if caller already set transition state, + # as the function is idempotent and will not overwrite existing valid transitions + self.set_module_state_transition(db, module_name, "shutdown") + + # Determine shutdown timeout (do NOT use get_reboot_timeout()) + timeouts = self._load_transition_timeouts() + shutdown_timeout = int(timeouts.get("shutdown", self._TRANSITION_TIMEOUT_DEFAULTS["shutdown"])) + + interval = 2 + waited = 0 + + key = self._transition_key() + while waited < shutdown_timeout: + # Get current transition state + entry = db.get_all(db.STATE_DB, key) or {} + + # (a) Someone else completed the graceful phase + if entry.get("state_transition_in_progress", "False") == "False": + return True + + # (b) Platform reports oper Offline — complete & clear transition + try: + oper = self.get_oper_status() + if oper and str(oper).lower() == "offline": + if not self.clear_module_state_transition(db, module_name): + sys.stderr.write(f"Graceful shutdown for module {module_name} failed to clear transition state.\n") + return True + except Exception as e: + # Don't fail the graceful gate on a transient platform call error + sys.stderr.write("Graceful shutdown for module {} failed to get oper status: {}\n".format(module_name, str(e))) + + # Check if the transition has timed out based on the recorded start time + # This handles cases where multiple agents might be waiting + if self.is_module_state_transition_timed_out(db, module_name, shutdown_timeout): + # Clear only if we can confirm it's actually timed out + if not self.clear_module_state_transition(db, module_name): + sys.stderr.write(f"Graceful shutdown for module {module_name} timed out and failed to clear transition state.\n") + else: + sys.stderr.write("Graceful shutdown for module {} timed out.\n".format(module_name)) + return False + + time.sleep(interval) + waited += interval + + # Final timeout check before clearing - use recorded start time, not our local wait time + if self.is_module_state_transition_timed_out(db, module_name, shutdown_timeout): + if not self.clear_module_state_transition(db, module_name): + sys.stderr.write(f"Graceful shutdown for module {module_name} timed out and failed to clear transition state.\n") + else: + sys.stderr.write("Graceful shutdown for module {} timed out.\n".format(module_name)) + + return False + + # ############################################################ + # Centralized APIs for CHASSIS_MODULE_TABLE transition flags + # ############################################################ + + def set_module_state_transition(self, db, module_name: str, transition_type: str): + """ + Atomically mark the given module as being in a state transition if not already in progress. + + This function is thread-safe and prevents race conditions when multiple agents + (chassis_modules.py, chassisd, reboot) attempt to set module state transitions + simultaneously by using a file-based lock. + + Args: + db: Connected SonicV2Connector + module_name: e.g., 'DPU0' + transition_type: 'shutdown' | 'startup' | 'reboot' + + Returns: + bool: True if transition was successfully set, False if already in progress + """ + allowed = {"shutdown", "startup", "reboot"} + ttype = (transition_type or "").strip().lower() + if ttype not in allowed: + sys.stderr.write(f"Invalid transition_type='{transition_type}' for module {module_name}\n") + return False + + module = module_name.strip().upper() + key = f"CHASSIS_MODULE_TABLE|{module}" + with self._transition_operation_lock(): + # Check if a transition is already in progress + existing_entry = db.get_all(db.STATE_DB, key) or {} + if existing_entry.get("state_transition_in_progress", "False").lower() in ("true", "1", "yes", "on"): + # Already in progress - check if it's timed out + timeout_seconds = int(self._load_transition_timeouts().get( + existing_entry.get("transition_type", "shutdown"), + self._TRANSITION_TIMEOUT_DEFAULTS.get("shutdown", 180) + )) + + if not self.is_module_state_transition_timed_out(db, module_name, timeout_seconds): + # Still valid, don't overwrite + return False + + # Timed out, clear and proceed with new transition + if not self.clear_module_state_transition(db, module_name): + sys.stderr.write(f"Failed to clear timed-out transition for module {module_name} before setting new one.\n") + return False + # Set new transition atomically + db.hset(db.STATE_DB, key, "state_transition_in_progress", "True") + db.hset(db.STATE_DB, key, "transition_type", ttype) + db.hset(db.STATE_DB, key, "transition_start_time", datetime.now(timezone.utc).isoformat()) + return True + + def clear_module_state_transition(self, db, module_name: str): + """ + Clear transition flags for the given module after a transition completes. + Field-scoped update to avoid clobbering concurrent writers. + + This function is thread-safe and uses the same lock as set_module_state_transition + to prevent race conditions. + + Args: + db: Connected SonicV2Connector. + module_name: The name of the module (e.g., 'DPU0'). + + Returns: + bool: True if the transition state was cleared successfully, False otherwise. + """ + with self._transition_operation_lock(): + key = f"CHASSIS_MODULE_TABLE|{module_name}" + try: + # Mark not in-progress and clear type (prevents stale 'startup' blocks) + db.hset(db.STATE_DB, key, "state_transition_in_progress", "False") + db.hset(db.STATE_DB, key, "transition_type", "") + db.hset(db.STATE_DB, key, "transition_start_time", "") + return True + except Exception as e: + sys.stderr.write(f"Failed to clear module state transition for {module_name}: {e}\n") + return False + + def get_module_state_transition(self, db, module_name: str) -> dict: + """ + Return the transition entry for a given module from STATE_DB. + + Note: This is a read-only operation and doesn't require locking. + + Returns: + dict with keys: state_transition_in_progress, transition_type, + transition_start_time (if present). + """ + key = f"CHASSIS_MODULE_TABLE|{module_name}" + return db.get_all(db.STATE_DB, key) or {} + + def is_module_state_transition_timed_out(self, db, module_name: str, timeout_seconds: int) -> bool: + """ + Check whether the state transition for the given module has exceeded timeout. + + Note: This is a read-only operation and doesn't require locking. + + Args: + db: Connected SonicV2Connector + module_name: e.g., 'DPU0' + timeout_seconds: max allowed seconds for the transition + + Returns: + True if transition exceeded timeout, False otherwise. + """ + entry = self.get_module_state_transition(db, module_name) + + # Missing entry means no active transition recorded; allow new operation to proceed. + if not entry: + return True + + # Only consider timeout if a transition is actually in progress + inprog = str(entry.get("state_transition_in_progress", "")).lower() in ("1", "true", "yes", "on") + if not inprog: + return True + + start_str = entry.get("transition_start_time") + if not start_str: + # If no start time, assume it's not timed out to be safe + return False + + # Parse ISO format datetime with timezone + try: + t0 = datetime.fromisoformat(start_str) + except Exception: + # Bad format → fail-safe to timed out + return True + + if t0.tzinfo is None: + # If timezone-naive, assume UTC + t0 = t0.replace(tzinfo=timezone.utc) + + age = (datetime.now(timezone.utc) - t0).total_seconds() + return age > timeout_seconds + ############################################## # Component methods ############################################## diff --git a/tests/module_base_test.py b/tests/module_base_test.py index 025849e9f..62ef9cc69 100644 --- a/tests/module_base_test.py +++ b/tests/module_base_test.py @@ -1,11 +1,16 @@ +import unittest +from unittest.mock import patch, MagicMock, call from sonic_platform_base.module_base import ModuleBase -import pytest -import json -import os import fcntl -from unittest.mock import patch, MagicMock, call +import importlib +import builtins from io import StringIO +import sys +import os import shutil +import contextlib +from types import ModuleType + class MockFile: def __init__(self, data=None): @@ -36,13 +41,13 @@ class TestModuleBase: def test_module_base(self): module = ModuleBase() not_implemented_methods = [ - [module.get_dpu_id], - [module.get_reboot_cause], - [module.get_state_info], - [module.get_pci_bus_info], - [module.pci_detach], - [module.pci_reattach], - ] + [module.get_dpu_id], + [module.get_reboot_cause], + [module.get_state_info], + [module.get_pci_bus_info], + [module.pci_detach], + [module.pci_reattach], + ] for method in not_implemented_methods: exception_raised = False @@ -57,35 +62,333 @@ def test_module_base(self): def test_sensors(self): module = ModuleBase() - assert(module.get_num_voltage_sensors() == 0) - assert(module.get_all_voltage_sensors() == []) - assert(module.get_voltage_sensor(0) == None) + assert module.get_num_voltage_sensors() == 0 + assert module.get_all_voltage_sensors() == [] + assert module.get_voltage_sensor(0) is None module._voltage_sensor_list = ["s1"] - assert(module.get_all_voltage_sensors() == ["s1"]) - assert(module.get_voltage_sensor(0) == "s1") - assert(module.get_num_current_sensors() == 0) - assert(module.get_all_current_sensors() == []) - assert(module.get_current_sensor(0) == None) + assert module.get_all_voltage_sensors() == ["s1"] + assert module.get_voltage_sensor(0) == "s1" + assert module.get_num_current_sensors() == 0 + assert module.get_all_current_sensors() == [] + assert module.get_current_sensor(0) is None module._current_sensor_list = ["s1"] - assert(module.get_all_current_sensors() == ["s1"]) - assert(module.get_current_sensor(0) == "s1") - + assert module.get_all_current_sensors() == ["s1"] + assert module.get_current_sensor(0) == "s1" + + +class DummyModule(ModuleBase): + def __init__(self, name="DPU0"): + self.name = name + # Mock the _state_db_connector to avoid swsscommon dependency in tests + self._state_db_connector = MagicMock() + + def get_name(self): + return self.name + + def set_admin_state(self, up): + return True # Dummy override + + +class TestModuleBaseGracefulShutdown: + # ==== graceful shutdown tests (match timeouts + centralized helpers) ==== + + @patch("sonic_platform_base.module_base.time", create=True) + def test_graceful_shutdown_handler_success(self, mock_time): + dpu_name = "DPU0" + mock_time.time.return_value = 1710000000 + mock_time.sleep.return_value = None + + module = DummyModule(name=dpu_name) + module._state_db_connector.get_all.side_effect = [ + {"state_transition_in_progress": "True"}, + {"state_transition_in_progress": "False"}, + ] + + # Mock the race condition protection to allow the transition to be set + with patch.object(module, "get_name", return_value=dpu_name), \ + patch.object(module, "_load_transition_timeouts", return_value={"shutdown": 10}), \ + patch.object(module, "set_module_state_transition", return_value=True), \ + patch.object(module, "is_module_state_transition_timed_out", return_value=False): + result = module.graceful_shutdown_handler() + assert result is True + + @patch("sonic_platform_base.module_base.time", create=True) + def test_graceful_shutdown_handler_timeout(self, mock_time): + dpu_name = "DPU1" + mock_time.time.return_value = 1710000000 + mock_time.sleep.return_value = None + + module = DummyModule(name=dpu_name) + # Keep it perpetually "in progress" so the handler’s wait path runs + module._state_db_connector.get_all.return_value = { + "state_transition_in_progress": "True", + "transition_type": "shutdown", + "transition_start_time": "2024-01-01T00:00:00", + } + + with patch.object(module, "get_name", return_value=dpu_name), \ + patch.object(module, "_load_transition_timeouts", return_value={"shutdown": 5}), \ + patch.object(module, "set_module_state_transition", return_value=True), \ + patch.object(module, "is_module_state_transition_timed_out", return_value=True): + result = module.graceful_shutdown_handler() + assert result is False + + @staticmethod + @patch("sonic_platform_base.module_base.time", create=True) + def test_graceful_shutdown_handler_offline_clear(mock_time): + mock_time.time.return_value = 123456789 + mock_time.sleep.return_value = None + + module = DummyModule(name="DPUX") + module._state_db_connector.get_all.return_value = { + "state_transition_in_progress": "True", + "transition_type": "shutdown", + "transition_start_time": "2024-01-01T00:00:00", + } + + with patch.object(module, "get_name", return_value="DPUX"), \ + patch.object(module, "get_oper_status", return_value="Offline"), \ + patch.object(module, "_load_transition_timeouts", return_value={"shutdown": 5}), \ + patch.object(module, "is_module_state_transition_timed_out", return_value=False), \ + patch.object(module, "set_module_state_transition", return_value=True): + result = module.graceful_shutdown_handler() + assert result is True + + @staticmethod + def test_transition_timeouts_platform_missing(): + """If platform is missing, defaults are used.""" + from sonic_platform_base import module_base as mb + class Dummy(mb.ModuleBase): ... + mb.ModuleBase._TRANSITION_TIMEOUTS_CACHE = None + with patch("os.path.exists", return_value=False): + d = Dummy() + assert d._load_transition_timeouts()["reboot"] == mb.ModuleBase._TRANSITION_TIMEOUT_DEFAULTS["reboot"] + + @staticmethod + def test_transition_timeouts_reads_value(): + """platform.json dpu_reboot_timeout and dpu_shutdown_timeout are honored.""" + from sonic_platform_base import module_base as mb + from unittest import mock + class Dummy(mb.ModuleBase): ... + mb.ModuleBase._TRANSITION_TIMEOUTS_CACHE = None + with patch("os.path.exists", return_value=True), \ + patch("builtins.open", new_callable=mock.mock_open, + read_data='{"dpu_reboot_timeout": 42, "dpu_shutdown_timeout": 123}'): + d = Dummy() + assert d._load_transition_timeouts()["reboot"] == 42 + assert d._load_transition_timeouts()["shutdown"] == 123 + + @staticmethod + def test_transition_timeouts_open_raises(): + """On read error, defaults are used.""" + from sonic_platform_base import module_base as mb + class Dummy(mb.ModuleBase): ... + mb.ModuleBase._TRANSITION_TIMEOUTS_CACHE = None + with patch("os.path.exists", return_value=True), \ + patch("builtins.open", side_effect=FileNotFoundError): + d = Dummy() + assert d._load_transition_timeouts()["reboot"] == mb.ModuleBase._TRANSITION_TIMEOUT_DEFAULTS["reboot"] + + # ==== coverage: centralized transition helpers ==== + + def test_transition_key_uses_get_name(self, monkeypatch): + m = ModuleBase() + monkeypatch.setattr(m, "get_name", lambda: "DPUX", raising=False) + assert m._transition_key() == "CHASSIS_MODULE_TABLE|DPUX" + + def test_set_module_state_transition_writes_expected_fields(self): + module = DummyModule() + module._state_db_connector.get_all.return_value = {} + + with patch.object(module, '_transition_operation_lock', side_effect=contextlib.nullcontext): + result = module.set_module_state_transition(module._state_db_connector, "DPU9", "startup") + + assert result is True # Should successfully set the transition + + # Check that 'hset' was called with the correct arguments + expected_calls = [ + call(module._state_db_connector.STATE_DB, "CHASSIS_MODULE_TABLE|DPU9", "state_transition_in_progress", "True"), + call(module._state_db_connector.STATE_DB, "CHASSIS_MODULE_TABLE|DPU9", "transition_type", "startup"), + call(module._state_db_connector.STATE_DB, "CHASSIS_MODULE_TABLE|DPU9", "transition_start_time", unittest.mock.ANY), + ] + module._state_db_connector.hset.assert_has_calls(expected_calls, any_order=True) + + def test_set_module_state_transition_invalid_type(self): + module = DummyModule() + module._state_db_connector.get_all.return_value = {} + + with patch('sys.stderr', new_callable=StringIO) as mock_stderr: + result = module.set_module_state_transition(module._state_db_connector, "DPU9", "invalid_type") + assert result is False + assert "Invalid transition_type" in mock_stderr.getvalue() + module._state_db_connector.hset.assert_not_called() + + def test_set_module_state_transition_race_condition_protection(self, monkeypatch): + module = DummyModule() + module._state_db_connector.get_all.return_value = { + "state_transition_in_progress": "True", + "transition_type": "shutdown", + "transition_start_time": "..." + } + + def fake_is_timed_out(db, module_name, timeout_seconds): + # This is the check inside set_module_state_transition + return False # Not timed out + + monkeypatch.setattr(module, "is_module_state_transition_timed_out", fake_is_timed_out, raising=False) + + # Mock _load_transition_timeouts to avoid file access + monkeypatch.setattr(module, "_load_transition_timeouts", lambda: {"shutdown": 180}) + with patch.object(module, '_transition_operation_lock', side_effect=contextlib.nullcontext): + result = module.set_module_state_transition(module._state_db_connector, "DPU9", "shutdown") + + assert result is False # Should fail to set due to existing active transition + + def test_clear_module_state_transition_success(self): + module = DummyModule() + + with patch.object(module, '_transition_operation_lock', side_effect=contextlib.nullcontext): + result = module.clear_module_state_transition(module._state_db_connector, "DPU9") + + assert result is True + + # Check that 'hset' was called to clear the flags + expected_calls = [ + call(module._state_db_connector.STATE_DB, "CHASSIS_MODULE_TABLE|DPU9", "state_transition_in_progress", "False"), + call(module._state_db_connector.STATE_DB, "CHASSIS_MODULE_TABLE|DPU9", "transition_type", ""), + call(module._state_db_connector.STATE_DB, "CHASSIS_MODULE_TABLE|DPU9", "transition_start_time", ""), + ] + module._state_db_connector.hset.assert_has_calls(expected_calls, any_order=True) + + def test_clear_module_state_transition_failure(self, monkeypatch): + module = DummyModule() + module._state_db_connector.hset.side_effect = Exception("DB error") + + with patch.object(module, '_transition_operation_lock', side_effect=contextlib.nullcontext), \ + patch('sys.stderr', new_callable=StringIO) as mock_stderr: + result = module.clear_module_state_transition(module._state_db_connector, "DPU9") + assert result is False + assert "Failed to clear module state transition" in mock_stderr.getvalue() + + def test_get_module_state_transition_passthrough(self): + expect = {"state_transition_in_progress": "True", "transition_type": "reboot"} + module = DummyModule() + module._state_db_connector.get_all.return_value = expect + got = module.get_module_state_transition(module._state_db_connector, "DPU5") + assert got is expect + + # ==== coverage: is_module_state_transition_timed_out variants ==== + + def test_is_transition_timed_out_not_in_progress(self, monkeypatch): + module = DummyModule() + monkeypatch.setattr( + module, "get_module_state_transition", + lambda *_: {"state_transition_in_progress": "False"}, + raising=False + ) + # If not in progress, it's not timed out (it's completed) + assert module.is_module_state_transition_timed_out(object(), "DPU0", 1) + + def test_is_transition_timed_out_no_entry(self, monkeypatch): + module = DummyModule() + monkeypatch.setattr(module, "get_module_state_transition", lambda *_: {}, raising=False) + assert module.is_module_state_transition_timed_out(object(), "DPU0", 1) + + def test_is_transition_timed_out_no_start_time(self, monkeypatch): + module = DummyModule() + monkeypatch.setattr( + module, "get_module_state_transition", lambda *_: {"state_transition_in_progress": "True"}, raising=False + ) + # Current implementation returns False when no start time is present (to be safe) + assert not module.is_module_state_transition_timed_out(object(), "DPU0", 1) + + def test_is_transition_timed_out_bad_timestamp(self, monkeypatch): + module = DummyModule() + monkeypatch.setattr( + module, "get_module_state_transition", + lambda *_: { + "state_transition_in_progress": "True", + "transition_start_time": "bad" + }, + raising=False + ) + assert module.is_module_state_transition_timed_out(object(), "DPU0", 1) + + def test_is_transition_timed_out_false(self, monkeypatch): + from datetime import datetime, timezone, timedelta + start = (datetime.now(timezone.utc) - timedelta(seconds=1)).isoformat() + module = DummyModule() + monkeypatch.setattr( + module, "get_module_state_transition", + lambda *_: { + "state_transition_in_progress": "True", + "transition_start_time": start + }, + raising=False + ) + assert not module.is_module_state_transition_timed_out(object(), "DPU0", 9999) + + def test_is_transition_timed_out_true(self, monkeypatch): + from datetime import datetime, timezone, timedelta + start = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() + module = DummyModule() + monkeypatch.setattr( + module, "get_module_state_transition", + lambda *_: { + "state_transition_in_progress": "True", + "transition_start_time": start + }, + raising=False + ) + assert module.is_module_state_transition_timed_out(object(), "DPU0", 1) + + # ==== coverage: import-time exposure of helper aliases ==== + @staticmethod + def test_helper_exports_exposed(): + # The helpers are available as methods on ModuleBase; importing + # them as top-level symbols is not required. Verify presence on class. + from sonic_platform_base.module_base import ModuleBase as MB + assert hasattr(MB, 'set_module_state_transition') and callable(getattr(MB, 'set_module_state_transition')) + assert hasattr(MB, 'clear_module_state_transition') and callable(getattr(MB, 'clear_module_state_transition')) + assert hasattr(MB, 'is_module_state_transition_timed_out') and callable(getattr(MB, 'is_module_state_transition_timed_out')) + + +class TestModuleBasePCIAndSensors: def test_pci_entry_state_db(self): - module = ModuleBase() - mock_connector = MagicMock() - module.state_db_connector = mock_connector - - module.pci_entry_state_db("0000:00:00.0", "detaching") - mock_connector.hset.assert_has_calls([ - call("PCIE_DETACH_INFO|0000:00:00.0", "bus_info", "0000:00:00.0"), - call("PCIE_DETACH_INFO|0000:00:00.0", "dpu_state", "detaching") - ]) - - module.pci_entry_state_db("0000:00:00.0", "attaching") - mock_connector.delete.assert_called_with("PCIE_DETACH_INFO|0000:00:00.0") - - mock_connector.hset.side_effect = Exception("DB Error") - module.pci_entry_state_db("0000:00:00.0", "detaching") + module = DummyModule() + + # Test "detaching" — implementation writes a dict with bus_info and dpu_state + module.pci_entry_state_db("0000:01:00.0", "detaching") + module._state_db_connector.set.assert_called_with( + module._state_db_connector.STATE_DB, + "PCIE_DETACH_INFO|0000:01:00.0", + { + "bus_info": "0000:01:00.0", + "dpu_state": "detaching" + } + ) + + # Test "attaching" — implementation deletes specific fields on attach + module.pci_entry_state_db("0000:02:00.0", "attaching") + module._state_db_connector.delete.assert_any_call( + module._state_db_connector.STATE_DB, + "PCIE_DETACH_INFO|0000:02:00.0", + "bus_info" + ) + module._state_db_connector.delete.assert_any_call( + module._state_db_connector.STATE_DB, + "PCIE_DETACH_INFO|0000:02:00.0", + "dpu_state" + ) + + def test_pci_entry_state_db_exception(self): + module = DummyModule() + module._state_db_connector.set.side_effect = Exception("DB write error") + + with patch('sys.stderr', new_callable=StringIO) as mock_stderr: + module.pci_entry_state_db("0000:01:00.0", "detaching") + # Implementation writes a more specific message + assert "Failed to write pcie bus info to state database" in mock_stderr.getvalue() def test_file_operation_lock(self): module = ModuleBase() @@ -262,3 +565,103 @@ def test_module_post_startup(self): with patch.object(module, 'handle_pci_rescan', return_value=True), \ patch.object(module, 'handle_sensor_addition', return_value=False): assert module.module_post_startup() is False + + +class TestStateDbConnectorSwsscommonOnly: + @patch('sonic_platform_base.module_base.SonicV2Connector') + def test_initialize_state_db_connector_success(self, mock_connector): + from sonic_platform_base.module_base import ModuleBase + mock_db = MagicMock() + mock_connector.return_value = mock_db + module = ModuleBase() + assert module._state_db_connector == mock_db + mock_db.connect.assert_called_once_with(mock_db.STATE_DB) + + @patch('sonic_platform_base.module_base.SonicV2Connector') + def test_initialize_state_db_connector_exception(self, mock_connector): + from sonic_platform_base.module_base import ModuleBase + mock_db = MagicMock() + mock_db.connect.side_effect = RuntimeError("Connection failed") + mock_connector.return_value = mock_db + + with patch('sys.stderr', new_callable=StringIO) as mock_stderr: + module = ModuleBase() + assert module._state_db_connector is None + assert "Failed to connect to STATE_DB" in mock_stderr.getvalue() + + def test_state_db_connector_uses_swsscommon_only(self): + import importlib + import sys + from types import ModuleType + from unittest.mock import patch + + # Fake swsscommon package + swsscommon.swsscommon module + pkg = ModuleType("swsscommon") + pkg.__path__ = [] # mark as package + sub = ModuleType("swsscommon.swsscommon") + + class FakeV2: + def connect(self, *_): + pass + + sub.SonicV2Connector = FakeV2 + + with patch.dict(sys.modules, { + "swsscommon": pkg, + "swsscommon.swsscommon": sub + }, clear=False): + mb = importlib.import_module("sonic_platform_base.module_base") + importlib.reload(mb) + # Since __init__ calls it, we need to patch before creating an instance + with patch.object(mb.ModuleBase, '_initialize_state_db_connector') as mock_init_db: + mock_init_db.return_value = FakeV2() + instance = mb.ModuleBase() + assert isinstance(instance._state_db_connector, FakeV2) + + +# New test cases for set_admin_state_using_graceful_handler logic +class TestModuleBaseAdminState: + def test_set_admin_state_up_sets_startup_transition(self): + module = DummyModule() + # Create a manager to check call order + manager = MagicMock() + module.set_module_state_transition = manager.set_module_state_transition + module.set_admin_state = manager.set_admin_state + module.clear_module_state_transition = manager.clear_module_state_transition + manager.set_admin_state.return_value = True + manager.clear_module_state_transition.return_value = True + + result = module.set_admin_state_using_graceful_handler(True) + + assert result is True + # Verify that set_module_state_transition is called before set_admin_state + expected_calls = [ + call.set_module_state_transition(module._state_db_connector, "DPU0", "startup"), + call.set_admin_state(True), + call.clear_module_state_transition(module._state_db_connector, "DPU0"), + ] + manager.assert_has_calls(expected_calls) + + def test_set_admin_state_up_clears_transition(self): + module = DummyModule() + module.set_admin_state = MagicMock(return_value=True) + module.clear_module_state_transition = MagicMock(return_value=True) + + result = module.set_admin_state_using_graceful_handler(True) + + assert result is True + module.set_admin_state.assert_called_once_with(True) + module.clear_module_state_transition.assert_called_once() + + def test_set_admin_state_down_success(self): + module = DummyModule() + module.graceful_shutdown_handler = MagicMock(return_value=True) + module.set_admin_state = MagicMock(return_value=True) + module.clear_module_state_transition = MagicMock(return_value=True) + + result = module.set_admin_state_using_graceful_handler(False) + + assert result is True + module.graceful_shutdown_handler.assert_called_once() + module.set_admin_state.assert_called_once_with(False) + assert module.clear_module_state_transition.call_count == 1 \ No newline at end of file