Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,97 @@ function find_perfect_aliases!(
return aliases
end

"""
$TYPEDSIGNATURES

Analytically remove any equations in `state` which are only incident on a single variable.
"""
function remove_constant_variables!(state::TearingState; allow_parameter::Bool = true, kwargs...)
StateSelection.complete!(state.structure)
(; additional_observed, original_eqs, fullvars, structure, sys) = state
(; graph, var_to_diff) = structure

diff_to_var = invview(var_to_diff)
eqs = equations(state)
eqs_to_rm = Int[]
vars_to_rm = Int[]
vars_to_rm_set = BitSet()
fullvars_set = Set{SymbolicT}(fullvars)
param_der_subber = SU.Substituter{false}(state.param_derivative_map)

# Preallocated buffer
snbors = Int[]
# Eliminating a variable can cause other equations to be dependent on only one
# un-eliminated variable. This shouldn't realistically run more than 2 times,
# but cap the iteration count at 4 nonetheless.
for _ in 1:4
removed_eq = false
for ieq in 𝑠vertices(graph)
# Check if this equation is incident on exactly 1 un-eliminated variable.
# We could just run `remove_constant_variables!` in a loop, but this
# is faster since we don't repeatedly rebuild structural information.
empty!(snbors)
append!(snbors, 𝑠neighbors(graph, ieq))
setdiff!(snbors, vars_to_rm_set)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed to prevent needing to recompute the neighbors after each variable removal? IMO it could use a comment if so.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add a comment (and just spam some more in general)

length(snbors) == 1 || continue
ivar = first(snbors)
# Only eliminate variables which are lowest order derivatives. Eliminating higher
# derivatives requires integrating to also eliminate the lower order variables,
# which is not something we can easily do at the moment.
diff_to_var[ivar] === nothing || continue
eq = eqs[ieq]
var = fullvars[ivar]
# This replicates a piece of `find_eq_solvables!` but calling that function directly:
# 1. requires creating `solvable_graph`, which we would then have to immediately
# discard since we're not populating solvability for every equation.
# 2. still requires us to call `LinearExpander` here. The function does return the
# same value as `b` below, but `eq.rhs - b` is not necessarily `a * var`, since
# e.g. if `eq.rhs` is `(2 + (3x + 4) * t)`, then `b` will be `4t + 2`. However,
# `eq.rhs - b` is `-4t + t*(4 + 3x)`.
# So doing it ourselves is just faster.
lex = Symbolics.LinearExpander(var; strict = true)
a, b, islin = lex(eq.rhs)
Comment thread
oscardssmith marked this conversation as resolved.
islin || continue
# `allow_symbolic = true` since we know this equation (directly or indirectly) only
# depends on `var`. Any variables present in it can only be ones we've already
# eliminated in `vars_to_rm`.
if !MTKTearing._check_allow_symbolic_parameter(
state, a, true, allow_parameter; fullvars_set
)
continue
end
removed_eq = true
push!(eqs_to_rm, ieq)
push!(vars_to_rm, ivar)
push!(vars_to_rm_set, ivar)
# `a` typically is faster to negate, since it is usually a constant or small expression
rhs = b / -a
push!(additional_observed, var ~ rhs)

# Also eliminate all derivatives of this variable.
v = var_to_diff[ivar]
while v !== nothing
# This is identical to how `eq_derivative!` works.
rhs = param_der_subber(
Symbolics.derivative(rhs, get_iv(sys)::SymbolicT; throw_no_derivative = true)
)
push!(additional_observed, default_toterm(fullvars[v]) ~ rhs)
push!(vars_to_rm, v)
push!(vars_to_rm_set, v)
v = var_to_diff[v]
end
end

removed_eq || break
end

old_to_new_eq, old_to_new_var = StateSelection.rm_eqs_vars!(
state, eqs_to_rm, vars_to_rm
)

return length(eqs_to_rm)
end

function alias_elimination!(state::TearingState; fully_determined = true,
print_underconstrained_variables = false, kwargs...)
StateSelection.complete!(state.structure)
Expand Down
1 change: 1 addition & 0 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ function _mtkcompile!(
state = ModelingToolkit.inputs_to_parameters!(state, discrete_inputs, OrderedSet{SymbolicT}())
state = ModelingToolkit.inputs_to_parameters!(state, inputs, outputs)
eliminate_perfect_aliases!(state)
remove_constant_variables!(state; kwargs...)
StateSelection.trivial_tearing!(state)
sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...)
if check_consistency
Expand Down
10 changes: 10 additions & 0 deletions test/reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,13 @@ ss = mtkcompile(sys)
@mtkcompile sys = System([D(x) ~ 2x, y ~ x], t; state_priorities = [y => 10])
@test isequal(only(unknowns(sys)), y)
end

@testset "Constant equations are removed" begin
@variables x(t) y(t) z(t)
@named sys = System([0 ~ 2x + 3t + 4, 0 ~ x * y + 2, 0 ~ D(x) + D(z) + 2z], t)
ts = TearingState(sys)
ModelingToolkit.remove_constant_variables!(ts)
dx = ModelingToolkit.default_toterm(unwrap(D(x)))
@test isequal(ts.additional_observed, [x ~ (3t + 4) / -2, dx ~ (-3//2), y ~ 2 / (-x)])
@test isequal(equations(ts), [0 ~ D(z) + 2z - 3/2])
end
Loading