"""
Exact/Approximate Recursive Small-Step Semi-Decentralized A* (RS-SDA*)

Communication may be synchronized conditioned on states, joint actions, or joint observations.
May be further extended to beliefs.

This module implements the RS-SDA* algorithm (exact and approximate) for solving semi-decentralized POMDPs (SDec-POMDPs)
with state-, joint-action-, or joint-observation-based communication triggers. The algorithm handles systems where agents
operate in a decentralized manner by default but synchronize intermittently. The current version is designed for
two-agent systems, but can be extended.

Algorithm Overview:
    RS-SDA* performs A* search over the space of partial policies, expanding one agent's
    action assignment at a time. The key innovation is tracking two parallel probability
    flows: decentralized (agents act independently) and centralized (agents coordinate).
    State-based triggers partition the belief at each stage based on which states require
    synchronization.

Key Components:
    - Policy: Represents partial/complete policies with decentralized and centralized
      action mappings, observation history clustering, and belief distributions.
    - SDecPOMDPModel: Encapsulates the SDec-POMDP model (transitions, observations, rewards).
    - SDecPOMDP: Main solver class implementing the A* search with heuristic computation.

Optional Approximation Techniques (TI1-TI4):
    - TI1: Early termination via weighted majority voting on centralization decisions;
      effective for stochastic observations and in centralized settings
    - TI2: Progress-based pruning to limit search depth per policy
    - TI3: Recursive horizon limiting with tail approximation [QMDP/HYBRID/POMDP]
    - TI4: Max clustering; effective for large, continuous, or stochastic observations
      and for complex belief dynamics

Author: [Mahdi Al-Husseini]
License: MIT (https://opensource.org/license/mit/)
"""
from dataclasses import dataclass, field
from heapq import heappush, heappop, nsmallest, heapify
from itertools import count
import math
import sys
import numpy as np
from numba import jit
from typing import List, Tuple, Dict, Optional, Union, Any
import psutil
# ==========================================
# Type Definitions & Constants
# ==========================================
# Primitives matching paper notation
BeliefID = int # Index into self.dists
StateID = int # Index into nstates
ActionID = int # Single agent action index
JointActionID = int # Joint action index
ObsID = int # Observation index
Prob = float # Probability [0, 1]
# Policy Structures
# Structure: [Stage][Decentralized=0/Centralized=1]
# Dec: [AgentIdx][HistoryIdx] -> ActionID
# Cen: [ClusterIdx] -> [JointActionID]
DecentralizedPol = List[List[int]]
CentralizedPol = List[List[int]]
StagePolicy = List[Union[List[DecentralizedPol], CentralizedPol]]
FullPolicy = List[StagePolicy]
# Constants
EPSILON = 1e-12 # Probability threshold for numerical stability (near-zero checks)
TOLERANCE = 1e-8 # Value comparison tolerance for A* search pruning
HASH_FACTOR = 10**12 + 39 # Factor for hashing probability distributions
class MemoryLimitExceeded(Exception):
    """Raised when memory usage exceeds the configured limit."""
    pass


def get_memory_usage_gb() -> float:
    """Returns current process memory usage in GB."""
    process = psutil.Process()
    return process.memory_info().rss / (1024 ** 3)
def fast_dynamics_sparse(belief, T_csr_list, O_csr_list, nactions, nstates, nobs):
    """
    Optimized sparse dynamics using flatnonzero to skip zero-probability states.
    """
    p_obs_all = np.zeros((nactions, nobs), dtype=np.float64)
    joint_unnorm_all = np.zeros((nactions, nobs, nstates), dtype=np.float64)
    # 1. Compute Next State Distribution
    for a in range(nactions):
        # belief (dense) @ T (sparse) -> next_states (dense)
        next_states = belief @ T_csr_list[a]
        # 2. OPTIMIZATION: Get indices of non-zero states immediately
        active_states = np.flatnonzero(next_states > EPSILON)
        if active_states.size == 0:
            continue
        # 3. Iterate only active states
        O_a = O_csr_list[a]
        for s_prime in active_states:
            val_ns = next_states[s_prime]
            start = O_a.indptr[s_prime]
            end = O_a.indptr[s_prime + 1]
            if start == end:
                continue
            for idx in range(start, end):
                o = O_a.indices[idx]
                prob = val_ns * O_a.data[idx]
                joint_unnorm_all[a, o, s_prime] = prob
                p_obs_all[a, o] += prob
    return p_obs_all, joint_unnorm_all
@jit(nopython=True, cache=True)
def fast_dynamics(belief, T, O, nactions, nstates, nobs):
    """
    Optimized dense dynamic calculations.
    """
    # PRE-ALLOCATE OUTPUTS
    next_states_all = np.zeros((nactions, nstates), dtype=np.float64)
    p_obs_all = np.zeros((nactions, nobs), dtype=np.float64)
    joint_unnorm_all = np.zeros((nactions, nobs, nstates), dtype=np.float64)
    # 1. COMPUTE NEXT STATE DISTRIBUTION
    for a in range(nactions):
        for s in range(nstates):
            val = belief[s]
            if val > 1e-12:  # Numba JIT requires literal constant
                for s_prime in range(nstates):
                    next_states_all[a, s_prime] += val * T[a, s, s_prime]
    # 2. COMPUTE OBSERVATION PROBABILITIES & JOINT BELIEFS
    for a in range(nactions):
        for s_prime in range(nstates):
            val_ns = next_states_all[a, s_prime]
            if val_ns > 0:
                for o in range(nobs):
                    prob = val_ns * O[a, s_prime, o]
                    joint_unnorm_all[a, o, s_prime] = prob
                    p_obs_all[a, o] += prob
    return p_obs_all, joint_unnorm_all
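
# Illustrative sketch (not part of the solver): the two loops in fast_dynamics
# can be read as a pair of tensor contractions. The helper name
# `belief_step_dense` and the toy numbers are assumptions made for this
# example only; the array layouts T[a, s, s'] and O[a, s', o] follow the
# indexing used above.

```python
import numpy as np


def belief_step_dense(belief, T, O):
    """Return P(o | b, a) and the unnormalized posteriors for every action."""
    # next_states[a, s'] = sum_s b(s) * T[a, s, s']
    next_states = np.einsum("s,asr->ar", belief, T)
    # joint[a, o, s'] = next_states[a, s'] * O[a, s', o]
    joint = np.einsum("ar,aro->aor", next_states, O)
    # p_obs[a, o] = sum_s' joint[a, o, s']
    p_obs = joint.sum(axis=2)
    return p_obs, joint


# Toy model: 1 joint action, 2 states, 2 observations (made-up numbers).
T = np.array([[[0.9, 0.1], [0.2, 0.8]]])
O = np.array([[[0.7, 0.3], [0.4, 0.6]]])
b = np.array([0.5, 0.5])
p_obs, joint = belief_step_dense(b, T, O)
# Normalizing one slice gives the posterior belief after observing o = 0.
posterior_o0 = joint[0, 0] / p_obs[0, 0]
```

# Dividing joint[a, o] by p_obs[a, o] yields the normalized posterior, which is
# presumably why the solver keeps both arrays.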
def int_tuple(plist: Union[List[float], np.ndarray, None]) -> Tuple[int, ...]:
    if plist is None:
        return tuple()
    # Create a sparse tuple representation (index, value) ignoring zeros.
    # SDecPOMDP.get_init uses this same helper so the initial belief and
    # successor beliefs share one dictionary key format.
    return tuple(
        x * HASH_FACTOR + int(float(plist[x]) * HASH_FACTOR)
        for x in range(len(plist))
        if plist[x] > 0
    )
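
# Illustrative sketch (not part of the solver): how a sparse integer key of
# this kind can be used to assign one ID per distinct belief. `belief_key`
# re-implements the expression above so the example is self-contained;
# `intern_belief` and `belief_ids` are hypothetical names for this example.

```python
HASH_FACTOR = 10**12 + 39  # same constant as in the module


def belief_key(probs):
    """One integer per non-zero entry, combining state index and probability."""
    return tuple(
        i * HASH_FACTOR + int(float(p) * HASH_FACTOR)
        for i, p in enumerate(probs)
        if p > 0
    )


belief_ids = {}


def intern_belief(probs):
    """Return the existing ID for this belief, or allocate the next one."""
    return belief_ids.setdefault(belief_key(probs), len(belief_ids))


first = intern_belief([0.5, 0.0, 0.5])
again = intern_belief([0.5, 0.0, 0.5])
other = intern_belief([0.25, 0.25, 0.5])
```

# Two beliefs that agree to roughly 12 decimal places map to the same key, so
# repeated beliefs reuse one ID instead of growing the table.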
def cumprod(lens):
    # Takes an array [l_0, l_1, ..., l_{n-1}] and returns the array of cumulative products
    # [1, l_0, l_0*l_1, ..., l_0*l_1*...*l_{n-2}], as well as the full product
    # l_0*l_1*...*l_{n-1} separately.
    ll = len(lens)
    div = [1] * ll
    for idx in range(ll - 1):
        div[idx + 1] = div[idx] * lens[idx]
    return div, div[ll - 1] * lens[ll - 1]


def product(list1, list2, mult):
    # Given lists L1, L2, computes the list of indices x + M*y
    # corresponding to pairs (x, y), where x in L1, y in L2.
    return [x + mult * y for y in list2 for x in list1]


def lists_product(lists, mults, nlists):
    # Given a list L of lists L_0, ..., L_{n-1}, computes the list of indices
    # x_0 + M_1*x_1 + ... + M_{n-1}*x_{n-1} corresponding to tuples (x_0, x_1, ..., x_{n-1}),
    # where x_0 in L_0, x_1 in L_1, ..., x_{n-1} in L_{n-1}.
    # A list equal to [-1] is skipped and contributes nothing to the index.
    list1 = lists[0]
    for idx in range(1, nlists):
        if lists[idx] != [-1]:
            m = mults[idx]
            l_curr = lists[idx]
            list1 = [x + m * y for y in l_curr for x in list1]
    return list1


def lists_product2(list1idx, list1, rlens, mults, nlists):
    # Computes a lists_product where only one list is not a range.
    return lists_product([list1 if idx == list1idx else range(rlens[idx]) for idx in range(nlists)], mults, nlists)
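
# Illustrative sketch (not part of the solver): the helpers above implement
# mixed-radix indexing, where a tuple (x_0, ..., x_{n-1}) is flattened to
# x_0 + M_1*x_1 + ... with strides M_i given by cumulative products. The names
# `strides` and `flat_indices` are assumptions for this example, and the
# [-1] skip rule of lists_product is deliberately left out.

```python
from itertools import product as cartesian


def strides(lens):
    """Cumulative products [1, l_0, l_0*l_1, ...] plus the total size."""
    out = [1] * len(lens)
    for i in range(len(lens) - 1):
        out[i + 1] = out[i] * lens[i]
    return out, out[-1] * lens[-1]


def flat_indices(lists, mults):
    """Flat index of every tuple in the Cartesian product, x_0 varying fastest."""
    return [
        sum(m * x for m, x in zip(mults, reversed(combo)))
        for combo in cartesian(*reversed(lists))
    ]


mults, total = strides([2, 3])                    # two factors of size 2 and 3
subset = flat_indices([[0, 1], [0, 2]], mults)    # restrict the second factor
everything = flat_indices([range(2), range(3)], mults)
```

# Iterating the reversed lists makes the first factor vary fastest, which
# matches the ordering produced by the nested comprehension in lists_product.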
class Policy:
    def __init__(self, policy: FullPolicy, ncluster: List[List[int]], dists: List[List[BeliefID]], prob: List[List[Prob]],
                 clustering: List[Tuple], values: Optional[List[float]] = None, heuristics: Optional[List[float]] = None,
                 depth: Optional[int] = None, final_cidx_value: float = 0.0,
                 dists_cen: Optional[List[List[BeliefID]]] = None, prob_cen: Optional[List[List[Prob]]] = None,
                 dec_split: Optional[List[float]] = None, clustering_cen: Optional[List[Tuple]] = None,
                 step_cen_values: Optional[List[float]] = None, suffixes: Optional[List[Any]] = None):
        # Mutable defaults are created per instance so that instances built
        # with default arguments never share (and mutate) the same list.
        self.policy = policy
        self.ncluster = ncluster
        self.dists = dists
        self.prob = prob
        self.clustering = clustering
        self.clustering_cen = clustering_cen if clustering_cen is not None else []
        self.values = values if values is not None else []
        self.heuristics = heuristics if heuristics is not None else [math.inf]
        self.depth = depth
        self.dists_cen = dists_cen if dists_cen is not None else [[-1]]
        self.prob_cen = prob_cen if prob_cen is not None else [[1.0]]
        self.dec_split = dec_split if dec_split is not None else [1.0]
        self.final_cidx_value = final_cidx_value
        self.step_cen_values = step_cen_values if step_cen_values is not None else []
        self.suffixes = suffixes if suffixes is not None else []

    def policy_copy(self, idx: int, aidx: int) -> 'Policy':
        # Copy-on-write: only the path down to agent aidx at stage idx is copied.
        policy = self.policy.copy()
        policy[idx] = policy[idx].copy()
        policy[idx][0] = policy[idx][0].copy()
        policy[idx][0][aidx] = policy[idx][0][aidx].copy()
        heuristics = self.heuristics.copy()
        return Policy(policy, self.ncluster, self.dists, self.prob, self.clustering, self.values, heuristics, self.depth, self.final_cidx_value, self.dists_cen, self.prob_cen, self.dec_split, self.clustering_cen, self.step_cen_values, self.suffixes)

    def policy_copy_cent(self, idx: int, jaidx: int) -> 'Policy':
        policy = self.policy.copy()
        policy[idx] = policy[idx].copy()
        policy[idx][1] = policy[idx][1].copy()
        policy[idx][1][jaidx] = policy[idx][1][jaidx].copy()
        heuristics = self.heuristics.copy()
        return Policy(policy, self.ncluster, self.dists, self.prob, self.clustering, self.values, heuristics, self.depth, self.final_cidx_value, self.dists_cen, self.prob_cen, self.dec_split, self.clustering_cen, self.step_cen_values, self.suffixes)

    def policy_copy_laststage(self, idx: int, aidx: int) -> 'Policy':
        policy = self.policy.copy()
        policy[idx] = policy[idx].copy()
        policy[idx][0] = policy[idx][0].copy()
        policy[idx][0][aidx] = policy[idx][0][aidx].copy()
        return Policy(policy, self.ncluster, self.dists, self.prob, self.clustering, self.values, [], self.depth, self.final_cidx_value, self.dists_cen, self.prob_cen, self.dec_split, self.clustering_cen, self.step_cen_values, self.suffixes)

    def policy_copy_laststage_cent(self, idx: int) -> 'Policy':
        policy = self.policy.copy()
        policy[idx] = policy[idx].copy()
        policy[idx][1] = policy[idx][1].copy()
        heuristics = self.heuristics.copy()
        return Policy(policy, self.ncluster, self.dists, self.prob, self.clustering, self.values, heuristics, self.depth, self.final_cidx_value, self.dists_cen, self.prob_cen, self.dec_split, self.clustering_cen, self.step_cen_values, self.suffixes)

    def cluster_copy(self) -> 'Policy':
        return Policy(self.policy.copy(), self.ncluster.copy(), self.dists.copy(), self.prob.copy(),
                      self.clustering.copy(), self.values.copy(), self.heuristics, self.depth, self.final_cidx_value, self.dists_cen.copy(), self.prob_cen.copy(), self.dec_split.copy(), self.clustering_cen.copy(), self.step_cen_values.copy(), self.suffixes.copy())
@dataclass
class RSSDAConfig:
    """
    Default configuration hyperparameters for the RSSDA solver. May be overwritten by benchmarks/drivers.
    """
    # Search Horizon
    maxh: int
    # Heuristic & Search Control
    IEmin2: int = 3  # Depth (d) of information-sharing stages for decentralized heuristic computation
    maxit: int = 200  # [early heuristic termination technique] Max iterations per stage expansion
    alpha: float = 0.2  # [early heuristic termination technique] Threshold for dynamically abandoning heuristics early
    heuristic_type: str = "HYBRID"  # "QMDP", "POMDP", or "HYBRID"
    # Approximation Flags (The "TI" Tiers)
    algorithm: str = "exact"
    TI1: bool = False  # Interleaving Planning/Execution
    TI2: bool = False  # Progress-based Pruning
    TI3: bool = False  # Tail Value Approximation
    TI4: bool = False  # Memory-Bounded Clustering
    # TI1 Settings (Interleaving)
    score_limit: int = 20
    cen_threshold: float = 0.6
    sm_temperature: float = 0.6
    adaptive_check: int = 10
    # TI2 Settings (Pruning)
    # iter_limit: Total iteration budget. Per-entity budget B = iter_limit / (nagents + 1).
    # IMPORTANT: Set iter_limit such that B >= max expected clusters per entity per stage.
    # If an entity has more clusters than B, some may not be explored before pruning.
    iter_limit: int = 1000
    # TI3 Settings (Tail Approximation)
    rec_limit: int = 2
    tail_heuristic_type: Optional[str] = None  # Defaults to heuristic_type if None
    hybrid_r: int = 0
    # TI4 Settings (Clustering)
    max_clusters: int = 20
    # Resource Limits
    memory_limit_gb: Optional[float] = 16.0  # Memory limit in GB; None = no limit
    memory_check_interval: int = 100  # Check memory every N iterations
    # Misc
    output: bool = False

    def __post_init__(self):
        if self.tail_heuristic_type is None:
            self.tail_heuristic_type = self.heuristic_type
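
# Illustrative sketch (not part of the solver): a quick sanity check for the
# TI2 guidance in the comments above, B = iter_limit / (nagents + 1). The
# function names are hypothetical; only the formula comes from the config.

```python
def per_entity_budget(iter_limit: int, nagents: int) -> float:
    """Budget per entity: n agents plus one centralized component."""
    return iter_limit / (nagents + 1)


def budget_covers(iter_limit: int, nagents: int, max_clusters_per_entity: int) -> bool:
    """True if every cluster of an entity can be visited before pruning."""
    return per_entity_budget(iter_limit, nagents) >= max_clusters_per_entity


default_budget = per_entity_budget(1000, 2)   # defaults with two agents
ok_default = budget_covers(1000, 2, 20)
too_small = budget_covers(30, 2, 20)
```

# With the default iter_limit of 1000 and two agents, each entity gets roughly
# 333 iterations, comfortably above the default max_clusters of 20.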
@dataclass
class TriggerProfile:
    """
    Encapsulates the Generalized Trigger Function Phi(C).

    Attributes:
        sync_actions: Set of joint action indices that trigger synchronization
        sync_observations: Set of joint observation indices that trigger synchronization
        state_mask: Boolean mask where True indicates a sync state
    """
    sync_actions: set = field(default_factory=set)
    sync_observations: set = field(default_factory=set)
    state_mask: Optional[np.ndarray] = None
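
# Illustrative sketch (not part of the solver): one way a boolean state mask
# can partition a belief into a synchronized part and a decentralized part,
# as described in the module docstring. `split_belief` is a hypothetical
# helper; the solver's own belief-splitting routine may differ in detail.

```python
import numpy as np


def split_belief(belief, state_mask):
    """Return (centralized belief, decentralized belief, decentralized mass)."""
    belief = np.asarray(belief, dtype=np.float64)
    cen_mass = float(belief[state_mask].sum())
    dec_mass = float(belief[~state_mask].sum())
    cen = np.where(state_mask, belief, 0.0)
    dec = np.where(state_mask, 0.0, belief)
    # Each part is renormalized; an empty part is reported as None.
    cen = cen / cen_mass if cen_mass > 0 else None
    dec = dec / dec_mass if dec_mass > 0 else None
    return cen, dec, dec_mass


mask = np.zeros(4, dtype=bool)
mask[[3]] = True                      # state 3 is the only sync state
cen, dec, p_dec = split_belief([0.1, 0.3, 0.4, 0.2], mask)
```

# Here 20% of the probability mass sits on the sync state, so p_dec is 0.8 and
# the two conditional beliefs each sum to one.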
class SDecPOMDPModel:
    """
    Encapsulates the static definition of the SDec-POMDP problem.
    Handles loading from raw inputs or cached sparse/dense structures.
    """
    def __init__(self, nagents, nstates, nactions, nobs,
                 transitions=None, obs=None, rewards=None, init_beliefs=None,
                 nacts_factor=None, nobs_factor=None,
                 cached_data=None, sync_states=None,
                 sync_actions=None, sync_observations=None):
        self.nagents = nagents
        self.nstates = nstates
        self.nactions = nactions
        self.nobs = nobs
        self.nacts_factor = nacts_factor
        self.nobs_factor = nobs_factor
        self.init_beliefs = init_beliefs
        # 1. Sync trigger logic (Model property)
        self.sync_states = sync_states if sync_states is not None else []
        # NOTE: `self.sink_state` is retained as a deprecated attribute for
        # backward compatibility with benchmarks that reference it, but it is
        # no longer consulted internally by RSSDA. The algorithm now flags
        # ONLY user-supplied sync_states in the state_mask; any benchmark that
        # wants a specific state (e.g. an absorbing terminal state, or a
        # co-location "goal" state) to act as a synchronization point must
        # include that state explicitly in its sync_states list.
        self.sink_state = nstates - 1
        state_mask = np.zeros(nstates, dtype=bool)
        if self.sync_states:
            state_mask[self.sync_states] = True
        # 2. Create Unified Profile
        self.trigger_profile = TriggerProfile(
            sync_actions=set(sync_actions) if sync_actions else set(),
            sync_observations=set(sync_observations) if sync_observations else set(),
            state_mask=state_mask
        )
        # --- Data Loading Logic ---
        self.use_sparse = False
        self.T = None
        self.O = None
        self.T_csr_list = None
        self.O_csr_list = None
        self.RA = None
        if cached_data is not None:
            self._load_from_cache(cached_data)
        else:
            self._load_from_raw(transitions, obs, rewards)

    def _load_from_cache(self, cached_data: Dict[str, Any]) -> None:
        # Check for sparse matrices (v5 cache format)
        if 'T_csr_list' in cached_data and cached_data.get('sparse', False):
            self.T_csr_list = cached_data['T_csr_list']
            self.O_csr_list = cached_data['O_csr_list']
            self.use_sparse = True
        elif 'T_np' in cached_data:
            self.T = cached_data['T_np']
            self.O = cached_data['O_np']
            self.use_sparse = False
        else:
            raise ValueError("Cache data missing both sparse and dense arrays")
        self.RA = cached_data['R_np'].astype(np.float64)

    def _load_from_raw(self, transitions: Union[List, Dict], obs: Union[List, Dict], rewards: List) -> None:
        # Handle sparse transitions dict (fallback for non-cached loading)
        if isinstance(transitions, dict):
            trans_size = self.nactions * self.nstates * self.nstates
            transitions = [transitions.get(i, 0.0) for i in range(trans_size)]
        # Normalize obs to sparse dict format
        if not isinstance(obs, dict):
            obs = {i: v for i, v in enumerate(obs) if v > 0}
        # Build dense arrays (standard initialization)
        self.T = np.array(transitions, dtype=np.float64).reshape(self.nactions, self.nstates, self.nstates)
        obs_size = self.nactions * self.nstates * self.nobs
        obs_dense = [obs.get(i, 0.0) for i in range(obs_size)]
        self.O = np.array(obs_dense, dtype=np.float64).reshape(self.nactions, self.nstates, self.nobs)
        self.RA = np.array(rewards, dtype=np.float64).reshape(self.nactions, self.nstates)
        self.use_sparse = False
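
# Illustrative sketch (not part of the solver): how a flat {index: value}
# dictionary expands into a dense tensor, as _load_from_raw does for the
# observation model. `dense_from_sparse` and the toy entries are assumptions
# for this example; the C-order layout matches NumPy's default reshape.

```python
import numpy as np


def dense_from_sparse(entries, shape):
    """Expand a flat sparse dict into a dense array of the given shape."""
    flat = np.zeros(int(np.prod(shape)), dtype=np.float64)
    for idx, val in entries.items():
        flat[idx] = val
    return flat.reshape(shape)


# 1 action, 2 states, 2 observations.
# Flat index = (a * nstates + s_prime) * nobs + o.
O = dense_from_sparse({0: 1.0, 3: 1.0}, (1, 2, 2))
```

# Index 3 decodes to (a=0, s'=1, o=1), so each state deterministically emits
# the observation with its own index in this toy model.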
class SDecPOMDP:
    def __init__(self, model: SDecPOMDPModel, config: RSSDAConfig, qmdp_data=None):
        # --- 1. Model Adoption ---
        self.model = model
        self.nagents = model.nagents
        self.nstates = model.nstates
        self.nactions = model.nactions
        self.nobs = model.nobs
        self.nacts_factor = model.nacts_factor
        self.nobs_factor = model.nobs_factor
        self.init_beliefs = model.init_beliefs
        self.sync_states = model.sync_states
        self.sink_state = model.sink_state
        self.trigger_profile = model.trigger_profile
        self.has_sync_triggers = bool(
            self.sync_states
            or self.trigger_profile.sync_actions
            or self.trigger_profile.sync_observations
        )
        # Performance shortcuts
        self.T = model.T
        self.O = model.O
        self.T_csr_list = model.T_csr_list
        self.O_csr_list = model.O_csr_list
        self.RA = model.RA
        self.use_sparse = model.use_sparse
        # --- 2. Solver Configuration ---
        self.config = config
        self.maxh = config.maxh
        self.maxit = config.maxit
        self.IEmin2 = config.IEmin2
        self.alpha = config.alpha
        self.output = config.output
        # Approximation Settings
        self.algorithm = config.algorithm
        self.TI1 = config.TI1
        self.TI2 = config.TI2
        self.TI3 = config.TI3
        self.TI4 = config.TI4
        self.score_limit = config.score_limit
        self.iter_limit = config.iter_limit
        self.rec_limit = config.rec_limit
        self.cen_threshold = config.cen_threshold
        self.sm_temperature = config.sm_temperature
        self.max_clusters = config.max_clusters
        self.adaptive_check = config.adaptive_check
        self.hybrid_r = config.hybrid_r
        self.heuristic_type = config.heuristic_type
        self.tail_heuristic_type = config.tail_heuristic_type
        # Resource limits
        self.memory_limit_gb = config.memory_limit_gb
        self.memory_check_interval = config.memory_check_interval
        self._last_reported_mem_gb = 0  # Track last reported memory level for 1GB increment logging
        self.init_call = True
        # --- 3. Pre-computation (Counters & Factors) ---
        self.nsq = self.nstates ** 2
        self.nso = self.nstates * self.nobs
        self.maxa = max(self.nacts_factor)
        self.a_prod = [1] * self.nagents
        self.o_prod = [1] * self.nagents
        for idx in range(self.nagents - 1):
            self.a_prod[idx + 1] = self.a_prod[idx] * self.nacts_factor[idx]
            self.o_prod[idx + 1] = self.o_prod[idx] * self.nobs_factor[idx]
        # Build ctrs mapping
        self.ctrs = {1: [1]}
        for a in range(self.nagents):
            for ctr_fix in list(self.ctrs.keys()):
                self.ctrs[ctr_fix] = [c * self.maxa + acta
                                      for c in self.ctrs[ctr_fix]
                                      for acta in range(self.nacts_factor[a])]
            for ctr_fix in self.ctrs[1]:
                self.ctrs[ctr_fix] = [ctr_fix]
        # Build counter-to-ja lookup
        self.ctr_to_ja = {}
        for full_ctr in self.ctrs[1]:
            actions = []
            temp = full_ctr
            while temp > 1:
                actions.append(temp % self.maxa)
                temp //= self.maxa
            actions.reverse()
            ja = sum(actions[i] * self.a_prod[i] for i in range(self.nagents))
            self.ctr_to_ja[full_ctr] = ja
        # --- 4. Caches & Heuristics ---
        self.dec_heuristic = dict()
        self.cen_heuristic = dict()
        self.newstatedist_dict = dict()
        self.terminal_dict = dict()
        self.cluster_dict = dict()
        self.belief_split_cache = {}
        self.cen_V = {}
        self.cen_V_hybrid = {}
        self.cen_Q = {}
        self.qmdp_cache = {}
        self.clusterctr_dict = {}
        self.terminalMDP_dict = {}
        self._terminal_batched = set()
        self._os_by_oa_cache = {}
        # Initial Belief Setup
        self.dist_dict = {int_tuple(self.init_beliefs): 0}
        self.dists = [self.init_beliefs]
        self.dists_sparse = {}
        self.reward_list = (self.RA @ self.init_beliefs).tolist()
        # Action Masks
        self.valid_actions_per_state = model.valid_actions_per_state if hasattr(model, 'valid_actions_per_state') else None
        self.use_action_masks = self.valid_actions_per_state is not None
        self.valid_actions_per_position = model.valid_actions_per_position if hasattr(model, 'valid_actions_per_position') else None
        self.use_position_action_masks = self.valid_actions_per_position is not None
        self._valid_actions_cache = {}
        # --- 5. Dynamics Dispatch Setup ---
        if self.use_sparse:
            self.T_repr = self.T_csr_list
            self.O_repr = self.O_csr_list
            self.dynamics_fn = fast_dynamics_sparse
        else:
            self.T_repr = self.T
            self.O_repr = self.O
            self.dynamics_fn = fast_dynamics
        # --- 6. QMDP Initialization ---
        if qmdp_data is not None:
            self.qmdp_Q = qmdp_data['qmdp_Q'][:self.maxh + 1]
        else:
            self.qmdp_Q = None
            self._solve_qmdp()
    # === ACTION MASK OPTIMIZATION METHODS ===
    def get_valid_actions_for_belief(self, dist_id: BeliefID) -> Union[range, List[JointActionID]]:
        """
        Returns list of valid joint actions for a given belief distribution.
        For sparse beliefs, this is the union of valid actions across all
        states with non-zero probability. Results are cached for efficiency.
        Falls back to all actions if action masks are not provided.
        """
        if not self.use_action_masks:
            return range(self.nactions)
        # Check cache first
        cached = self._valid_actions_cache.get(dist_id)
        if cached is not None:
            return cached
        belief = self.dists[dist_id]
        valid_set = set()
        for s, prob in enumerate(belief):
            if prob > EPSILON:
                valid_set.update(self.valid_actions_per_state.get(s, range(self.nactions)))
        # Convert to sorted list for consistent ordering
        valid_list = sorted(valid_set)
        if not valid_list:
            valid_list = list(range(self.nactions))  # Fallback to all actions
        self._valid_actions_cache[dist_id] = valid_list
        return valid_list

    @staticmethod
    def _wkey(w: float) -> int:
        return int(HASH_FACTOR * w)
    # ---------- Centralized DP heuristic ----------
    # Compute V_rh(b) where b is referenced by dist_id.
    def cen_dp_V(self, rh: int, dist_id: BeliefID) -> float:
        if rh <= 0:
            return 0.0
        key = (rh, dist_id)
        v = self.cen_V.get(key)
        if v is not None:
            return v
        # Maximize over valid joint actions (action mask optimization)
        best = -math.inf
        valid_actions = self.get_valid_actions_for_belief(dist_id)
        for ja in valid_actions:
            q = self.cen_dp_Q(rh, dist_id, ja)
            if q > best:
                best = q
        self.cen_V[key] = best
        # Also populate the heuristic map on the minimal key used elsewhere
        self.cen_heuristic[(rh, dist_id, 1)] = best
        return best

    # Compute Q_rh(b, a) = R(b,a) + sum_o P(o|b,a) V_{rh-1}(b_o')
    def cen_dp_Q(self, rh: int, dist_id: BeliefID, ja: JointActionID) -> float:
        key = (rh, dist_id, ja)
        q = self.cen_Q.get(key)
        if q is not None:
            return q
        r = self.reward_list[dist_id * self.nactions + ja]
        # Short-circuit: no future value at horizon 1
        if rh <= 1:
            self.cen_Q[key] = r
            return r
        # Only compute belief updates for h > 1
        sparse_transitions = self.get_terminal(dist_id, ja)
        exp = 0.0
        rh_1 = rh - 1
        for _, p_o, d_next in sparse_transitions:
            exp += p_o * self.cen_dp_V(rh_1, d_next)
        q = r + exp
        self.cen_Q[key] = q
        short_ctr = 2 + ja
        self.cen_heuristic[(rh, dist_id, short_ctr)] = q
        return q

    def cen_dp_V_hybrid(self, rh: int, dist_id: BeliefID, r_depth: int) -> float:
        # Wrapper for V that calls the hybrid Q
        if rh <= 0:
            return 0.0
        # Check cache to avoid redundant computation
        key = (rh, dist_id, r_depth)
        v = self.cen_V_hybrid.get(key)
        if v is not None:
            return v
        best = -math.inf
        # Action mask optimization: only iterate over valid actions
        valid_actions = self.get_valid_actions_for_belief(dist_id)
        for ja in valid_actions:
            # Call the hybrid Q function
            val = self.cen_dp_Q_hybrid(rh, dist_id, ja, r_depth)
            if val > best:
                best = val
        self.cen_V_hybrid[key] = best
        return best
    def _get_terminal_batched(self, dist: BeliefID, act: JointActionID) -> List[Tuple[ObsID, Prob, BeliefID]]:
        # 1. Prepare Inputs
        b = self.dists[dist]
        # 2. Call Optimized Function (Dispatched)
        p_obs_all, joint_unnorm = self.dynamics_fn(
            b, self.T_repr, self.O_repr, self.nactions, self.nstates, self.nobs
        )
        # 3. Process results into Dictionary
        for a in range(self.nactions):
            p_obs = p_obs_all[a]
            sparse_transitions = []
            valid_obs_indices = np.flatnonzero(p_obs > EPSILON)
            for o in valid_obs_indices:
                unnorm_posterior = joint_unnorm[a, o]
                _, did = self.get_init(unnorm_posterior)
                sparse_transitions.append((o, p_obs[o], did))
            self.terminal_dict[(dist, a)] = sparse_transitions
        return self.terminal_dict[(dist, act)]

    def _get_terminal_single_action(self, dist: BeliefID, act: JointActionID) -> List[Tuple[ObsID, Prob, BeliefID]]:
        """
        Compute successors for one requested action.
        The all-action batched path is useful when SDec branches repeatedly
        through centralized and decentralized action choices. In a no-trigger
        Dec-POMDP run, however, RS-SDA* should behave like RS-MAA*: only the
        requested action's posterior beliefs should be materialized.
        """
        b = self.dists[dist]
        if self.use_sparse:
            next_states = b @ self.T_repr[act]
            active_states = np.flatnonzero(next_states > EPSILON)
            p_obs = np.zeros(self.nobs, dtype=np.float64)
            joint_unnorm = np.zeros((self.nobs, self.nstates), dtype=np.float64)
            O_a = self.O_repr[act]
            for s_prime in active_states:
                val_ns = next_states[s_prime]
                start = O_a.indptr[s_prime]
                end = O_a.indptr[s_prime + 1]
                for idx in range(start, end):
                    o = O_a.indices[idx]
                    prob = val_ns * O_a.data[idx]
                    joint_unnorm[o, s_prime] = prob
                    p_obs[o] += prob
        else:
            next_states = b @ self.T_repr[act]
            # (nstates, nobs) joint, summed over states, then transposed to (nobs, nstates)
            joint_unnorm = self.O_repr[act] * next_states[:, None]
            p_obs = joint_unnorm.sum(axis=0)
            joint_unnorm = joint_unnorm.T
        sparse_transitions = []
        for o in np.flatnonzero(p_obs > EPSILON):
            _, did = self.get_init(joint_unnorm[o])
            sparse_transitions.append((int(o), float(p_obs[o]), did))
        self.terminal_dict[(dist, act)] = sparse_transitions
        return sparse_transitions
    # r_depth: The number of steps we are allowed to perform full POMDP branching
    def cen_dp_Q_hybrid(self, rh: int, dist_id: BeliefID, ja: JointActionID, r_depth: int) -> float:
        # 1. Base Case: If we have exhausted our "POMDP budget" (r_depth == 0),
        # we stop branching and return the QMDP value for the remaining horizon.
        if r_depth <= 0:
            belief = self.dists[dist_id]
            # self.qmdp_Q is shape (maxh+1, nactions, nstates)
            # We want the value for the remaining horizon 'rh'
            q_values_h = self.qmdp_Q[rh]
            # QMDP Value V(b) = max_a Sum_s b(s) Q(s,a)
            # But here we are computing Q(b, ja), so we just need the specific action ja
            return np.dot(belief, q_values_h[ja])
        # 2. Standard Memoization Check
        # We need to include r_depth in the key so we don't mix hybrid vs full values
        key = (rh, dist_id, ja, r_depth)
        q = self.cen_Q.get(key)
        if q is not None:
            return q
        r_val = self.reward_list[dist_id * self.nactions + ja]
        if rh <= 1:
            self.cen_Q[key] = r_val
            return r_val
        # 3. Recursive Step (Standard QPOMDP logic)
        sparse_transitions = self.get_terminal(dist_id, ja)
        exp = 0.0
        rh_1 = rh - 1
        # Decrement the depth budget for the next step
        next_r = r_depth - 1
        for _, p_o, d_next in sparse_transitions:
            # Recursively call V, which calls Q_hybrid
            exp += p_o * self.cen_dp_V_hybrid(rh_1, d_next, next_r)
        q = r_val + exp
        self.cen_Q[key] = q
        return q

    def exact_central_Q_sbt(self, rh: int, dist_id: BeliefID, ja: JointActionID, extra_horizon: int = 0) -> float:
        key = ("partQ", rh, dist_id, ja, extra_horizon)
        cached = self.cen_Q.get(key)
        if cached is not None:
            return cached
        r = self.reward_list[dist_id * self.nactions + ja]
        sparse_transitions = self.get_terminal(dist_id, ja)
        exp = 0.0
        rh_1 = rh - 1
        for _, p_o, d_next in sparse_transitions:
            c_id, d_id, p_dec = self.belief_split_by_id(d_next)
            # Determine if we're in the "tail" region where we use approximate heuristics
            # Both centralized and decentralized components use the same condition for consistency
            is_tail = self.TI3 and (rh_1 + extra_horizon <= self.rec_limit)
            v_c = 0.0
            if c_id != -1:
                if is_tail:
                    v_c = self.get_tail_centralized_value(rh_1 + extra_horizon, c_id)
                else:
                    v_c = self.get_core_centralized_value(rh_1 + extra_horizon, c_id)
            v_d = 0.0
            if p_dec > 0.0 and d_id != -1:
                if is_tail:
                    v_d = self.get_tail_centralized_value(rh_1 + extra_horizon, d_id)
                else:
                    v_d = self.get_core_centralized_value(rh_1 + extra_horizon, d_id)
            exp += p_o * ((1.0 - p_dec) * v_c + p_dec * v_d)
        q = r + exp
        self.cen_Q[key] = q
        return q
    def _solve_qmdp(self) -> None:
        """
        Solves the underlying MDP using vectorized Value Iteration.
        Populates self.qmdp_Q[h][a][s].
        """
        self.qmdp_Q = np.zeros((self.maxh + 1, self.nactions, self.nstates))
        if self.output:
            print(f"Pre-computing Q-MDP Q-values ({'sparse' if self.use_sparse else 'dense'})...", end=" ")
        for h in range(1, self.maxh + 1):
            # V(s') = max_a Q(s', a)
            v_prev = np.max(self.qmdp_Q[h-1], axis=0)
            # Expected future value: Sum(T(s,a,s') * V(s'))
            if self.use_sparse:
                for a in range(self.nactions):
                    # CSR @ vector
                    self.qmdp_Q[h, a, :] = self.RA[a, :] + self.T_repr[a] @ v_prev
            else:
                # einsum: 'asr,r->as'
                future_val = np.einsum('asr,r->as', self.T_repr, v_prev)
                self.qmdp_Q[h] = self.RA + future_val
        if self.output:
            print("Done.")

    def get_terminalMDP(self, init: BeliefID, h: int, ctr_fix: int = 1) -> float:
        """
        Equivalent to SDecPOMDP.get_terminalMDP for the MDP heuristic case.
        Returns: max_{ja extending ctr_fix} E_b[Q_MDP(s, ja, h)]

        Args:
            init: Belief distribution index
            h: Horizon (remaining horizon, should be >= 1)
            ctr_fix: Partial action counter (1 = maximize over all actions)

        Returns:
            MDP heuristic value
        """
        # Check cache first (matching SDecPOMDP's caching strategy)
        if ctr_fix == 1:
            cache_key = (init, h)
        else:
            cache_key = (init, h, ctr_fix)
        cached = self.terminalMDP_dict.get(cache_key)
        if cached is not None:
            return cached
        # Compute the value
        belief = np.asarray(self.dists[init], dtype=np.float64)
        best = -np.inf
        for full_ctr in self.ctrs[ctr_fix]:
            ja = self.ctr_to_ja[full_ctr]
            val = np.dot(belief, self.qmdp_Q[h, ja])
            if val > best:
                best = val
        # Cache and return
        self.terminalMDP_dict[cache_key] = best
        return best
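
# Illustrative sketch (not part of the solver): the dense branch of
# _solve_qmdp on a two-state toy MDP. `qmdp_q_values` is a hypothetical
# stand-alone function, and the transition and reward numbers are made up;
# the recursion Q_h = R + T @ max_a Q_{h-1} is the one used above.

```python
import numpy as np


def qmdp_q_values(T, R, maxh):
    """Finite-horizon Q-values Q[h, a, s] for T[a, s, s'] and R[a, s]."""
    nactions, nstates, _ = T.shape
    Q = np.zeros((maxh + 1, nactions, nstates))
    for h in range(1, maxh + 1):
        # V_{h-1}(s') = max_a Q_{h-1}(a, s')
        v_prev = Q[h - 1].max(axis=0)
        # Q_h(a, s) = R(a, s) + sum_s' T(a, s, s') V_{h-1}(s')
        Q[h] = R + np.einsum("asr,r->as", T, v_prev)
    return Q


# Action 0 stays in place, action 1 swaps the two states.
T = np.array([[[1.0, 0.0], [0.0, 1.0]],
              [[0.0, 1.0], [1.0, 0.0]]])
# Reward 1 for acting in state 1, regardless of the action.
R = np.array([[0.0, 1.0],
              [0.0, 1.0]])
Q = qmdp_q_values(T, R, 2)
# QMDP value of a uniform belief with two steps to go.
v_uniform = max(float(np.dot([0.5, 0.5], Q[2, a])) for a in range(2))
```

# The last line mirrors how a belief is scored elsewhere in the class: take the
# expectation of Q over states for each joint action, then maximize.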
# ---------- TI2 ----------
def check_progress_pruning(self, pi_c, idx, ctr, aidx, cidx, policyidx):
"""
Computes the semi-decentralized progress score (prog) and returns True (prune) when the
global counter (ctr) exceeds the progress allowed at this policy depth; otherwise returns False.
Formula:
prog = sigma * L + k * (L/(n+1)) + c + p * (L/(n+1) - |C|)
Where:
- sigma: Completed stages
- k: Completed entities in current stage (Centralized=0, Agent0=1, ...)
- c: Completed clusters for current entity
- p: Cumulative probability of completed clusters
- |C|: Total clusters for current entity
"""
# 1. Constants and Context
current_stage = idx - 1
L = self.iter_limit
n = self.nagents
n_entities = n + 1 # n Agents + 1 Centralized Component
# 2. Base Progress (Fully Completed Stages)
# sigma * L
progress = current_stage * L
# 3. Entity Progress (Completed Entities in Current Stage)
# Expansion Order: Centralized -> Agent 0 -> ... -> Agent n-1
if cidx:
# We are currently expanding the Centralized Component.
# No entities are fully completed in this stage yet.
k = 0
else:
# Centralized is done. Agents 0 to (aidx-1) are done.
# k = 1 (for Centralized) + aidx (for previous agents)
k = 1 + aidx
progress += k * (L / n_entities)
# 4. Cluster Progress (c) and Probability Mass (p)
# c = policyidx (number of clusters already fixed)
c = policyidx
p = 0.0
total_clusters = 0
if cidx:
# --- Centralized Component Logic ---
if current_stage < len(pi_c.dists_cen):
total_clusters = len(pi_c.dists_cen[current_stage])
# Sum probability of all fixed clusters (0 to c-1)
# prob_cen is a simple list [P(b0), P(b1), ...]
if current_stage < len(pi_c.prob_cen):
limit = min(c, len(pi_c.prob_cen[current_stage]))
for i in range(limit):
p += pi_c.prob_cen[current_stage][i]
elif aidx < n:
# --- Decentralized Agent Logic ---
if current_stage < len(pi_c.ncluster):
total_clusters = pi_c.ncluster[current_stage][aidx]
# Calculate p: Sum of joint probabilities consistent with
# agent 'aidx' being in local clusters 0 to c-1.
if current_stage < len(pi_c.prob) and c > 0:
divs, _ = cumprod(pi_c.ncluster[current_stage])
# lists_product2 generates all joint indices where agent 'aidx'
# has a local cluster index in the provided list (range(c)).
relevant_joint_indices = lists_product2(
aidx,
range(c),
pi_c.ncluster[current_stage],
divs,
self.nagents
)
for joint_idx in relevant_joint_indices:
if joint_idx < len(pi_c.prob[current_stage]):
p += pi_c.prob[current_stage][joint_idx]
# else: aidx == n (stage complete, waiting for expansion) - no cluster lookup needed
term_weight = (L / n_entities) - total_clusters
progress += c + (p * term_weight)
return ctr > progress
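# --- Illustrative sketch (not called by the solver) ---
# A pure-Python restatement of the progress formula from the docstring of
# check_progress_pruning, useful for checking the arithmetic by hand. The names
# `progress_score_sketch` and `should_prune_sketch` are hypothetical; the
# arguments mirror the docstring symbols (sigma, k, c, p, |C|, L, n).
def progress_score_sketch(sigma, k, c, p, total_clusters, L, n):
    """prog = sigma*L + k*(L/(n+1)) + c + p*(L/(n+1) - |C|)."""
    share = L / (n + 1)  # iteration budget per entity (n agents + centralized)
    return sigma * L + k * share + c + p * (share - total_clusters)


def should_prune_sketch(ctr, **kwargs):
    """Prune when the global counter has run ahead of the allowed progress."""
    return ctr > progress_score_sketch(**kwargs)


# Worked example with L = 300 and n = 2 (so each entity gets 100 iterations):
# one completed stage, centralized component done (k = 1), two of four clusters
# fixed with cumulative probability 0.5:
#   300 + 100 + 2 + 0.5 * (100 - 4) = 450.0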
# ---------- TI1 ----------
def get_horizon_centralization_scores(self, q_heap: List, top_n: int, threshold: float, temperature: float) -> Tuple[List[bool], List[float]]:
"""
Analyzes the top N nodes using Weighted Majority Voting.
The horizon length that possesses the highest total Softmax probability mass
is selected as the target horizon.
Handles tri-state centralization vectors where:
- True (1.0): Stage is centralized
- False (0.0): Stage is decentralized
- None: Stage is incomplete (excluded from voting for that stage)
For each stage, only nodes with complete (non-None) values contribute to the
weighted average. This prevents incomplete stages from being conflated with
decentralized stages.
"""
# 1. Peek at the top N nodes without popping them.
# Heap keys are negated values, so the smallest entries are the best nodes.
top_k_tuples = nsmallest(top_n, q_heap)
if not top_k_tuples:
return [], []
# 2. Extract Data (Value, Vector, Length)
# Store these in parallel lists for efficient numpy masking later
vals = []
vecs = []
lengths = []
for tup in top_k_tuples:
real_val = -tup[0] # Flip negative heap value back to positive
piv = tup[2]
vec = self.centralization_vector(piv)
vals.append(real_val)
vecs.append(vec)
lengths.append(len(vec))
vals = np.array(vals)
lengths = np.array(lengths)
finite_mask = np.isfinite(vals)
if not np.any(finite_mask):
return [], [] # No valid values to analyze
vals = vals[finite_mask]
lengths = lengths[finite_mask]
vecs = [vec for vec, keep in zip(vecs, finite_mask) if keep]
# 3. GLOBAL Softmax Weighting
# Compute weights for ALL candidates immediately.
# This allows high-value nodes to dominate the selection process.
# Shift values for numerical stability
shift_vals = (vals - np.max(vals)) / temperature
exp_vals = np.exp(shift_vals)
weights = exp_vals / np.sum(exp_vals)
# 4. Determine Best Horizon by Probability Mass
# Sum the weights for every unique length found.