Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ function eliminate_perfect_aliases!(state::TearingState)
return nothing
end

"""
$TYPEDSIGNATURES

Pick the target variable for an alias group. Irreducible variables must remain as
unknowns, so one of them is chosen as the target when the group contains any. Otherwise
the variable with the highest `state_priority` wins.
"""
function pick_alias_target(
fullvars::Vector{SymbolicT}, group_vars::Vector{Int}, state_priorities
)
irr_idx = findfirst(v -> isirreducible(fullvars[v]), group_vars)
irr_idx === nothing || return group_vars[irr_idx]
_, target_idx = findmax(Base.Fix1(getindex, state_priorities), group_vars)
return group_vars[target_idx]
end

"""
$TYPEDSIGNATURES

Expand All @@ -64,6 +80,10 @@ function find_perfect_aliases!(

# Not `IntDisjointSet` because we don't want singleton sets for every single variable
alias_groups = DisjointSet{Int}()
# Candidate alias equations `(ieq, v1_idx, v2_idx)`. Removal is decided below once
# each group's target is known: equations with a non-target irreducible endpoint
# must stay so the remaining irreducibles are still constrained to the target.
candidate_eqs = Tuple{Int, Int, Int}[]

for ieq in 1:nsrcs(graph)
snbors = 𝑠neighbors(graph, ieq)
Expand All @@ -82,7 +102,7 @@ function find_perfect_aliases!(
end
_ => continue
end
push!(eqs_to_rm, ieq)
push!(candidate_eqs, (ieq, snbors[1], snbors[2]))
push!(alias_groups, snbors[1])
push!(alias_groups, snbors[2])
union!(alias_groups, snbors[1], snbors[2])
Expand All @@ -96,12 +116,34 @@ function find_perfect_aliases!(
push!(set, var)
end

for aset in values(alias_sets)
_, target_idx = findmax(Base.Fix1(getindex, state.structure.state_priorities), aset)
target = aset[target_idx]
group_target = Dict{Int, Int}()
for (root, group_vars) in alias_sets
group_target[root] = pick_alias_target(fullvars, group_vars, state.structure.state_priorities)
end

# Queue an alias equation for removal only if both endpoints collapse onto the
# target after non-irreducibles are substituted -- i.e. the equation becomes
# `T ~ T`. Any equation with a non-target irreducible endpoint is kept; when the
# other endpoint is a non-irreducible, the existing substitution machinery below
# rewrites the kept equation into `I ~ T` form automatically.
for (ieq, v1, v2) in candidate_eqs
target = group_target[DataStructures.find_root!(alias_groups, v1)]
c1 = isirreducible(fullvars[v1]) ? v1 : target
c2 = isirreducible(fullvars[v2]) ? v2 : target
c1 == c2 && push!(eqs_to_rm, ieq)
end

for (root, group_vars) in alias_sets
target = group_target[root]
state.always_present[target] = true
for v in aset
for v in group_vars
v == target && continue
# Irreducibles other than the target stay as unknowns; only non-irreducibles
# are eliminated in favor of the target.
if isirreducible(fullvars[v])
state.always_present[v] = true
continue
end
push!(vars_to_rm, v)
subs[fullvars[v]] = fullvars[target]
push!(state.additional_observed, fullvars[v] ~ fullvars[target])
Expand Down
20 changes: 20 additions & 0 deletions test/reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,23 @@ ss = mtkcompile(sys)
@mtkcompile sys = System([D(x) ~ 2x, y ~ x], t; state_priorities = [y => 10])
@test isequal(only(unknowns(sys)), y)
end

@testset "Perfect aliases do not eliminate irreducible variables" begin
@variables x(t) y(t)
@variables e(t) [irreducible = true]
@variables c(t) [irreducible = true] d(t) [irreducible = true]
# Two independent alias groups:
# * {x, e} -- one irreducible; the non-irreducible `x` is eliminated as observed
# * {c, d, y} -- two irreducibles + one non-irreducible. `y` is eliminated, both
# irreducibles remain unknowns, bound by the surviving alias
# equation between them.
@mtkcompile sys = System([
D(x) ~ x,
D(c) ~ -c,
e ~ x,
c ~ d,
y ~ c
], t)

@test Set(unknowns(sys)) == Set([e, c, d])
end
Loading