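"""Generalized dataflow analysis framework for static program analysis.

A dataflow problem is described by five components:

- Direction: whether facts propagate forward or backward along the
  control-flow graph.
- Lattice: the abstract values and the operations defined on them.
- Transfer: per-statement functions mapping one lattice value to another.
- Meet: the lattice operation used to combine values where paths merge.
- Boundary: the value assumed at the entry/exit of the analysed code.
"""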
from dataclasses import dataclass, field
from typing import Dict, List, Set, Callable, Generic, TypeVar, Optional
from abc import ABC, abstractmethod
from enum import Enum
# Type of the abstract values a lattice operates on.
T = TypeVar('T')


# Lattice definition
class Lattice(ABC, Generic[T]):
    """Interface that every abstract-value lattice has to implement.

    All five operations are abstract; the class cannot be instantiated until
    a subclass overrides each of them.
    """

    @abstractmethod
    def top(self) -> T:
        """Return the top element of the lattice."""

    @abstractmethod
    def bottom(self) -> T:
        """Return the bottom element of the lattice."""

    @abstractmethod
    def meet(self, a: T, b: T) -> T:
        """Return the meet (greatest lower bound) of ``a`` and ``b``."""

    @abstractmethod
    def join(self, a: T, b: T) -> T:
        """Return the join (least upper bound) of ``a`` and ``b``."""

    @abstractmethod
    def less_than(self, a: T, b: T) -> bool:
        """Return whether ``a`` ≤ ``b`` holds in the lattice's partial order."""
# Direction
class Direction(Enum):
    """Direction in which a dataflow analysis propagates its facts."""

    # Facts flow along control-flow edges.
    FORWARD = "forward"
    # Facts flow against control-flow edges.
    BACKWARD = "backward"
# Dataflow problem
@dataclass
class DataflowProblem(Generic[T]):
    """Dataflow problem specification consumed by a solver.

    ``statements`` is optional and defaults to an empty list: nothing in this
    module reads it (statements are obtained through ``cfg.stmt(node)``), and
    requiring it made ``DataflowProblem(..., cfg=cfg)`` fail with a
    ``TypeError``.
    """

    # Whether facts propagate forward or backward through the CFG.
    direction: Direction
    # Lattice providing the meet operation used to merge values.
    lattice: Lattice[T]
    # Transfer functions keyed by CFG node id; each takes the incoming lattice
    # value and the node's statement and returns the outgoing value.  Nodes
    # without an entry are treated as identity by WorklistSolver.transfer.
    transfer: Dict[str, Callable[[T, 'Stmt'], T]]
    # Value used for nodes that have nothing to merge from.
    boundary: T
    # Value intended as the starting fact for every node before iteration.
    init: T
    # For control flow
    cfg: 'ControlFlowGraph'
    # Optional flat list of statements; each instance gets its own list.
    statements: List['Stmt'] = field(default_factory=list)
class WorklistSolver(Generic[T]):
    """Worklist algorithm for dataflow problems.

    The per-node tables are named relative to the *direction of the analysis*:

    * ``in_values[n]``  -- the value fed into ``n``'s transfer function, i.e.
      the meet over the transfer results of ``n``'s predecessors (forward) or
      successors (backward).
    * ``out_values[n]`` -- the result of applying ``n``'s transfer function.
    """

    def __init__(self, problem: DataflowProblem[T]):
        self.problem = problem
        self.in_values: Dict[str, T] = {}
        self.out_values: Dict[str, T] = {}
        self.worklist: Set[str] = set()

    def solve(self) -> Dict[str, T]:
        """Iterate the dataflow equations to a fixed point.

        Returns:
            ``in_values`` for a forward problem and ``out_values`` for a
            backward problem; in both cases this is the fact holding
            immediately before each statement in program order.
        """
        problem = self.problem
        cfg = problem.cfg
        forward = problem.direction == Direction.FORWARD

        # Every node starts from the problem's initial value.
        nodes = list(cfg.nodes)
        for node in nodes:
            self.in_values[node] = problem.init
            self.out_values[node] = problem.init

        # Every node is visited at least once, so each transfer function is
        # applied even when a node's input equals the initial value.
        self.worklist = set(nodes)

        while self.worklist:
            node = self.worklist.pop()

            if forward:
                in_val = self.compute_forward_in(node)
            else:
                in_val = self.compute_backward_in(node)
            self.in_values[node] = in_val

            out_val = self.transfer(node, in_val)
            if out_val == self.out_values[node]:
                # Nothing new to tell the nodes that depend on this one.
                continue
            self.out_values[node] = out_val

            # Requeue the nodes that consume this node's result: successors
            # when flowing forward, predecessors when flowing backward.
            dependents = cfg.successors(node) if forward else cfg.predecessors(node)
            self.worklist.update(dependents)

        return self.in_values if forward else self.out_values

    def compute_forward_in(self, node: str) -> T:
        """Meet the transfer results of ``node``'s predecessors.

        The boundary value is used for nodes without predecessors and is also
        merged into the CFG's entry node when that node has predecessors
        (for example, a loop back to the entry).
        """
        cfg = self.problem.cfg
        incoming = [self.out_values[pred] for pred in cfg.predecessors(node)]
        if not incoming or node == getattr(cfg, "entry", None):
            incoming.insert(0, self.problem.boundary)
        return self._meet_all(incoming)

    def compute_backward_in(self, node: str) -> T:
        """Meet the transfer results of ``node``'s successors.

        Nodes without successors receive the boundary value.
        """
        incoming = [
            self.out_values[succ] for succ in self.problem.cfg.successors(node)
        ]
        if not incoming:
            return self.problem.boundary
        return self._meet_all(incoming)

    def transfer(self, node: str, in_val: T) -> T:
        """Apply ``node``'s transfer function; identity if none is registered."""
        fn = self.problem.transfer.get(node)
        if fn is None:
            return in_val
        return fn(in_val, self.problem.cfg.stmt(node))

    def _meet_all(self, values: List[T]) -> T:
        """Fold a non-empty list of values with the lattice's meet."""
        meet = self.problem.lattice.meet
        result = values[0]
        for value in values[1:]:
            result = meet(result, value)
        return result
# Example: Constant Propagation Lattice
class ConstantPropagationLattice(Lattice[int]):
    """Flat constant-propagation lattice.

    Values are concrete constants plus two sentinels, encoded as floats:

    * ``UNDEF`` (``float('-inf')``) -- no value seen yet; the top element.
    * ``NAC``   (``float('inf')``)  -- "not a constant"; the bottom element.

    The order is ``NAC ≤ c ≤ UNDEF`` for every constant ``c``, with distinct
    constants incomparable.  This is the orientation under which ``meet``
    (the operation WorklistSolver uses to merge paths) drives conflicting
    constants to ``NAC``.

    NOTE(review): ``top()`` and ``bottom()`` previously returned ``NAC`` and
    ``UNDEF`` respectively, which contradicted ``meet``; they are swapped here
    so the five operations agree.  Confirm no caller relied on the old values.
    """

    NAC = float('inf')     # Not A Constant
    UNDEF = float('-inf')  # Undefined / no information yet

    def top(self) -> int:
        """Return ``UNDEF``, the identity element of ``meet``."""
        return self.UNDEF

    def bottom(self) -> int:
        """Return ``NAC``, the absorbing element of ``meet``."""
        return self.NAC

    def meet(self, a: int, b: int) -> int:
        """Greatest lower bound: equal values stay, conflicts become ``NAC``."""
        # UNDEF carries no information, so the other operand wins.
        if a == self.UNDEF:
            return b
        if b == self.UNDEF:
            return a
        if a == b:
            return a
        return self.NAC

    def join(self, a: int, b: int) -> int:
        """Least upper bound: the dual of ``meet``."""
        # NAC is below everything, so the other operand wins.
        if a == self.NAC:
            return b
        if b == self.NAC:
            return a
        if a == b:
            return a
        return self.UNDEF

    def less_than(self, a: int, b: int) -> bool:
        """Partial order (≤): ``NAC`` ≤ any value ≤ ``UNDEF``."""
        return a == b or a == self.NAC or b == self.UNDEF
# Example: Available Expressions
class AvailableExpressionsLattice(Lattice[Set[tuple]]):
    """Available-expressions lattice: sets of expressions ordered by subset.

    Under this order the bottom element is the empty set and the top element
    is the set of all expressions.  The universe cannot be derived here, so it
    may be supplied to the constructor.

    Args:
        universe: Every expression the analysis may encounter.  When omitted,
            ``top()`` falls back to an empty set, as it did before the
            parameter existed.
    """

    def __init__(self, universe: Optional[Set[tuple]] = None):
        # Frozen copy so later mutation of the caller's set cannot leak in.
        self._universe = None if universe is None else frozenset(universe)

    def top(self) -> Set[tuple]:
        """Return a fresh copy of the universe (empty if none was given)."""
        if self._universe is None:
            return set()
        return set(self._universe)

    def bottom(self) -> Set[tuple]:
        """Return the empty set, the least element under subset order."""
        return set()

    def meet(self, a: Set[tuple], b: Set[tuple]) -> Set[tuple]:
        """Greatest lower bound: intersection."""
        return a & b

    def join(self, a: Set[tuple], b: Set[tuple]) -> Set[tuple]:
        """Least upper bound: union."""
        return a | b

    def less_than(self, a: Set[tuple], b: Set[tuple]) -> bool:
        """Partial order (≤): ``a`` is a subset of ``b``."""
        return a.issubset(b)
# Live Variable Analysis
class LiveVariableAnalysis:
"""Live variable analysis"""
def __init__(self, cfg):
self.cfg = cfg
self.problem = DataflowProblem(
direction=Direction.BACKWARD,
lattice=SetLattice(),
transfer={},
boundary=set(),
init=set(),
cfg=cfg
)
def analyze(self) -> Dict[str, Set[str]]:
"""Analyze live variables"""
def transfer(stmt_id: str, stmt: 'Stmt', out_val: Set[str]) -> Set[str]:
in_val = set(out_val)
match stmt:
case Assign(x, e):
in_val.discard(x) # x is overwritten
in_val |= vars(e) # x used in e
case If(cond, _, _):
in_val |= vars(cond)
case _:
pass
return in_val
self.problem.transfer = transfer
solver = WorklistSolver(self.problem)
return solver.solve()
class SetLattice(Lattice[Set]):
    """Powerset lattice: sets ordered by subset.

    Under this order the bottom element is the empty set and the top element
    is the universe of all elements.  The universe cannot be derived here, so
    it may be supplied to the constructor.

    Args:
        universe: Every element a set in this lattice may contain.  When
            omitted, ``top()`` falls back to an empty set, as it did before
            the parameter existed.
    """

    def __init__(self, universe: Optional[Set] = None):
        # Frozen copy so later mutation of the caller's set cannot leak in.
        self._universe = None if universe is None else frozenset(universe)

    def top(self) -> Set:
        """Return a fresh copy of the universe (empty if none was given)."""
        if self._universe is None:
            return set()
        return set(self._universe)

    def bottom(self) -> Set:
        """Return the empty set, the least element under subset order."""
        return set()

    def meet(self, a: Set, b: Set) -> Set:
        """Greatest lower bound: intersection."""
        return a & b

    def join(self, a: Set, b: Set) -> Set:
        """Least upper bound: union."""
        return a | b

    def less_than(self, a: Set, b: Set) -> bool:
        """Partial order (≤): ``a`` is a subset of ``b``."""
        return a.issubset(b)