Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 48 additions & 15 deletions sonic-chassisd/scripts/chassisd
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ MODULE_ADMIN_UP = 1
MODULE_REBOOT_CAUSE_DIR = "/host/reboot-cause/module/"
MAX_HISTORY_FILES = 10

DP_STATE = 'dpu_data_plane_state'
DP_UPDATE_TIME = 'dpu_data_plane_time'
CP_STATE = 'dpu_control_plane_state'
CP_UPDATE_TIME = 'dpu_control_plane_time'


# This daemon should return non-zero exit code so that supervisord will
# restart it automatically.
exit_code = 0
Expand Down Expand Up @@ -812,18 +818,24 @@ class SmartSwitchModuleUpdater(ModuleUpdater):
def update_dpu_state(self, key, state):
"""
Update specific DPU state fields in chassisStateDB using the given key.
If state is 'down', set control plane, data plane states to down as well.
"""
try:
# Connect to the CHASSIS_STATE_DB using daemon_base
if not self.chassis_state_db:
self.chassis_state_db = daemon_base.db_connect("CHASSIS_STATE_DB")


# Prepare the fields to update
updates = {
"dpu_midplane_link_state": state,
"dpu_midplane_link_reason": "",
"dpu_midplane_link_time": get_formatted_time(),
}
# If midplane state is down, set control plane, data plane states to down as well
if state == "down":
updates[CP_STATE] = "down"
updates[DP_STATE] = "down"

# Update each field directly
for field, value in updates.items():
Expand Down Expand Up @@ -1170,11 +1182,6 @@ class SmartSwitchConfigManagerTask(ProcessTaskBase):

class DpuStateUpdater(logger.Logger):

DP_STATE = 'dpu_data_plane_state'
DP_UPDATE_TIME = 'dpu_data_plane_time'
CP_STATE = 'dpu_control_plane_state'
CP_UPDATE_TIME = 'dpu_control_plane_time'

def __init__(self, log_identifier, chassis):
super(DpuStateUpdater, self).__init__(log_identifier)

Expand Down Expand Up @@ -1229,12 +1236,12 @@ class DpuStateUpdater(logger.Logger):
return get_formatted_time()

def _update_dp_dpu_state(self, state):
self.dpu_state_table.hset(self.name, self.DP_STATE, state)
self.dpu_state_table.hset(self.name, self.DP_UPDATE_TIME, self._time_now())
self.dpu_state_table.hset(self.name, DP_STATE, state)
self.dpu_state_table.hset(self.name, DP_UPDATE_TIME, self._time_now())

def _update_cp_dpu_state(self, state):
self.dpu_state_table.hset(self.name, self.CP_STATE, state)
self.dpu_state_table.hset(self.name, self.CP_UPDATE_TIME, self._time_now())
self.dpu_state_table.hset(self.name, CP_STATE, state)
self.dpu_state_table.hset(self.name, CP_UPDATE_TIME, self._time_now())

def get_dp_state(self):
return 'up' if self._get_dp_state() else 'down'
Expand All @@ -1245,16 +1252,17 @@ class DpuStateUpdater(logger.Logger):
def update_state(self):

dp_current_state = self.get_dp_state()
_, dp_prev_state = self.dpu_state_table.hget(self.name, self.DP_STATE)
_, dp_prev_state = self.dpu_state_table.hget(self.name, DP_STATE)

if dp_current_state != dp_prev_state:
self._update_dp_dpu_state(dp_current_state)

cp_current_state = self.get_cp_state()
_, cp_prev_state = self.dpu_state_table.hget(self.name, self.CP_STATE)
_, cp_prev_state = self.dpu_state_table.hget(self.name, CP_STATE)

if cp_current_state != cp_prev_state:
self._update_cp_dpu_state(cp_current_state)
return [dp_current_state, cp_current_state]

def deinit(self):
self._update_dp_dpu_state('down')
Expand Down Expand Up @@ -1403,12 +1411,16 @@ class DpuStateManagerTask(ProcessTaskBase):
self.dpu_state_updater = dpu_state_updater
self.state_db = daemon_base.db_connect('STATE_DB')
self.app_db = daemon_base.db_connect('APPL_DB')
self.chassis_state_db = daemon_base.db_connect('CHASSIS_STATE_DB')
self.current_dp_state = None
self.current_cp_state = None

def task_worker(self):
sel = swsscommon.Select()
selectable = [
swsscommon.SubscriberStateTable(self.app_db, 'PORT_TABLE'),
swsscommon.SubscriberStateTable(self.state_db, 'SYSTEM_READY')
swsscommon.SubscriberStateTable(self.state_db, 'SYSTEM_READY'),
swsscommon.SubscriberStateTable(self.chassis_state_db, 'DPU_STATE')
]

for s in selectable:
Expand All @@ -1424,10 +1436,31 @@ class DpuStateManagerTask(ProcessTaskBase):
if state != swsscommon.Select.OBJECT:
continue

for s in selectable:
s.pops()
update_required = False

self.dpu_state_updater.update_state()
for s in selectable:
result = s.pop()
update_required = True # If there is any selectable object, we need to update the state
Comment on lines +1439 to +1443

Copilot AI May 20, 2025

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 'update_required' flag is being reinitialized for each selectable within the loop instead of accumulating conditions from all objects. Consider initializing 'update_required' to false before the loop and then OR-ing or accumulating the condition for each selectable to ensure proper state update decision.

Copilot uses AI. Check for mistakes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The update_required is set this way, because it shows a clear distinction on when we perform the actual update, and when the update is disabled, I think implementing it this way is more clearer to understand

if result is None:
continue
key, op, fvp = result # Changed from _ to fvp to match what we use below
# Check if this is the DPU_STATE table
if s.getDbConnector().getDbName() == 'CHASSIS_STATE_DB':
# Don't update if this is a change for another DPU
if key != self.dpu_state_updater.name:
update_required = False
continue
if op == 'SET' and isinstance(fvp, tuple):
fvs = dict(fvp)
# No need to update if the state is the same as the current state
if ('dpu_data_plane_state' in fvs and fvs['dpu_data_plane_state'] == self.current_dp_state) and \
('dpu_control_plane_state' in fvs and fvs['dpu_control_plane_state'] == self.current_cp_state):
update_required = False
continue
self.logger.log_info(f"DPU_STATE change detected: operation={op}, key={key}")

if update_required:
[self.current_dp_state, self.current_cp_state] = self.dpu_state_updater.update_state()

except KeyboardInterrupt:
pass
Expand Down
9 changes: 9 additions & 0 deletions sonic-chassisd/tests/mock_swsscommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def pop(self):
def pops(self):
return None

def getDbConnector(self):
return MockDbConnector()


class MockDbConnector:

def getDbName(self):
return 'CHASSIS_STATE_DB'

class RedisPipeline:
def __init__(self, db):
self.db = db
Expand Down
64 changes: 64 additions & 0 deletions sonic-chassisd/tests/test_chassisd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,3 +1765,67 @@ def is_valid_date(date_str):
if not date_value:
AssertionError("Date is not set!")
assert is_valid_date(date_value)

def test_smartswitch_moduleupdater_midplane_state_change():
"""Test that when midplane goes down, control plane and data plane states are set to down"""
chassis = MockSmartSwitchChassis()
index = 0
name = "DPU0"
desc = "DPU Module 0"
slot = 0
serial = "DPU0-0000"
module_type = ModuleBase.MODULE_TYPE_DPU
module = MockModule(index, name, desc, module_type, slot, serial)
module.set_midplane_ip()
chassis.module_list.append(module)

# Create the updater
module_updater = SmartSwitchModuleUpdater(SYSLOG_IDENTIFIER, chassis)
module_updater.midplane_initialized = True

# Mock chassis_state_db
chassis_state_db = {}
def mock_hset(key, field, value):
if key not in chassis_state_db:
chassis_state_db[key] = {}
chassis_state_db[key][field] = value

def mock_hget(key, field):
if key in chassis_state_db and field in chassis_state_db[key]:
return chassis_state_db[key][field]
return None

with patch.object(module_updater, 'chassis_state_db') as mock_db:
mock_db.hset = MagicMock(side_effect=mock_hset)
mock_db.hget = MagicMock(side_effect=mock_hget)

# Initially set midplane as up
module.set_midplane_reachable(True)
module_updater.check_midplane_reachability()

# Verify initial state
key = "DPU_STATE|" + name
assert chassis_state_db[key]["dpu_midplane_link_state"] == "up"

# Now set midplane as down
module.set_midplane_reachable(False)
module_updater.check_midplane_reachability()

# Verify all states are set to down
assert chassis_state_db[key]["dpu_midplane_link_state"] == "down"
assert chassis_state_db[key]["dpu_control_plane_state"] == "down"
assert chassis_state_db[key]["dpu_data_plane_state"] == "down"

# Verify timestamps are set
assert "dpu_midplane_link_time" in chassis_state_db[key]

# Verify time format
date_format = "%a %b %d %I:%M:%S %p UTC %Y"
def is_valid_date(date_str):
try:
datetime.strptime(date_str, date_format)
return True
except ValueError:
return False

assert is_valid_date(chassis_state_db[key]["dpu_midplane_link_time"])
Loading
Loading