Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions lib/ModelingToolkitBase/src/problems/odeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,19 @@ Base.@nospecializeinfer @fallback_iip_specialize function SciMLBase.ODEProblem{i
check_compatibility && check_compatible_system(ODEProblem, sys)

_iip = resolve_iip(iip, op)
f, u0,
p = process_SciMLProblem(
ODEFunction{_iip, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
eval_module, expression, check_compatibility, kwargs...
)
if _iip === true
f, u0, p = process_SciMLProblem(
ODEFunction{true, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
eval_module, expression, check_compatibility, kwargs...
)
else
f, u0, p = process_SciMLProblem(
ODEFunction{false, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
eval_module, expression, check_compatibility, kwargs...
)
end

kwargs = process_kwargs(
sys; expression, callback, eval_expression, eval_module, op, _skip_events, tspan, kwargs...
Expand Down
7 changes: 5 additions & 2 deletions lib/ModelingToolkitBase/src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ end

safe_vec(@nospecialize(x)) = x isa SymbolicT ? [x] : vec(x::Array{SymbolicT})

system(a::AffectSystem) = a.system
system(a::AffectSystem) = a.system::System
discretes(a::AffectSystem) = a.discretes
unknowns(a::AffectSystem) = a.unknowns
parameters(a::AffectSystem) = a.parameters
Expand Down Expand Up @@ -655,7 +655,10 @@ function namespace_affects(affect::AffectSystem, s)
# called `affectsys` for further namespacing
affsys = rename(affsys, nameof(s))
affsys = toggle_namespacing(affsys, true)
affsys = System(Equation[], get_iv(affsys); systems = [affsys], name = :affectsys)
affsys = System(
Equation[], get_iv(affsys)::SymbolicT, SymbolicT[], SymbolicT[];
systems = System[affsys], name = :affectsys
)
affsys = complete(affsys)
@set! affsys.tearing_state = old_ts
return AffectSystem(
Expand Down
12 changes: 6 additions & 6 deletions lib/ModelingToolkitBase/src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect)
end

function Base.hash(a::ImperativeAffect, s::UInt)
s = hash(a.f, s)
s = hash(a.obs, s)
s = hash(a.obs_syms, s)
s = hash(a.modified, s)
s = hash(a.mod_syms, s)
return hash(a.ctx, s)
s = hash(a.f, s)::UInt
s = hash(a.obs, s)::UInt
s = hash(a.obs_syms, s)::UInt
s = hash(a.modified, s)::UInt
s = hash(a.mod_syms, s)::UInt
return hash(a.ctx, s)::UInt
end

namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,8 @@ function SciMLBase.remake_initialization_data(
u0_constructor = get_u0_constructor(identity, typeof(newu0), floatT, false)
p_constructor = get_p_constructor(identity, typeof(newu0), floatT)
kws = maybe_build_initialization_problem(
sys, SciMLBase.isinplace(odefn), op, t0, guesses;
time_dependent_init, use_scc, initialization_eqs, floatT, fast_path = true,
sys, Val{SciMLBase.isinplace(odefn)}(), op, t0, guesses, floatT;
time_dependent_init, use_scc, initialization_eqs, fast_path = true,
u0_constructor, p_constructor, allow_incomplete = true, check_units = false,
missing_guess_value = meta.missing_guess_value
)
Expand Down
75 changes: 38 additions & 37 deletions lib/ModelingToolkitBase/src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1088,16 +1088,15 @@ struct GetUpdatedU0{GG, GIU}
get_initial_unknowns::GIU
end

function GetUpdatedU0(sys::AbstractSystem, initprob::SciMLBase.AbstractNonlinearProblem, op::AbstractDict)
@nospecialize initprob
function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict)
dvs = unknowns(sys)
eqs = equations(sys)
guessvars = trues(length(dvs))
for (i, var) in enumerate(dvs)
varval = get(op, var, COMMON_NOTHING)
guessvars[i] = varval === COMMON_NOTHING || !SU.isconst(varval)
end
get_guessvars = getu(initprob, dvs[guessvars])
get_guessvars = getu(initsys, dvs[guessvars])
get_initial_unknowns = getu(sys, Initial.(dvs))
return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns)
end
Expand Down Expand Up @@ -1170,14 +1169,14 @@ constructed is in implicit DAE form (`DAEProblem`). All other keyword arguments
to `InitializationProblem`.
"""
function maybe_build_initialization_problem(
sys::AbstractSystem, iip, op::AbstractDict, t, guesses;
sys::AbstractSystem, ::Val{iip}, op::SymmapT, t, guesses, ::Type{floatT};
time_dependent_init = is_time_dependent(sys), u0_constructor = identity,
p_constructor = identity, floatT = Float64, initialization_eqs = [],
p_constructor = identity, initialization_eqs = [],
use_scc = true, eval_expression = false, eval_module = @__MODULE__,
missing_guess_value = default_missing_guess_value(),
# Intercept `expression` because we don't support it here yet
implicit_dae = false, is_steadystateprob = false, expression = Val{false}, kwargs...
)
) where {iip, floatT}
guesses = merge(ModelingToolkitBase.guesses(sys), todict(guesses))

if t === nothing && is_time_dependent(sys)
Expand All @@ -1190,6 +1189,7 @@ function maybe_build_initialization_problem(
use_scc, u0_constructor, p_constructor, eval_expression, eval_module,
missing_guess_value, is_steadystateprob, kwargs...
)
initsys = initializeprob.f.sys::System
needs_remake = false
_u0 = state_values(initializeprob)
if _u0 !== nothing
Expand Down Expand Up @@ -1236,7 +1236,7 @@ function maybe_build_initialization_problem(
end

get_initial_unknowns = if time_dependent_init
GetUpdatedU0(sys, initializeprob, op)
GetUpdatedU0(sys, initsys, op)
else
nothing
end
Expand All @@ -1246,7 +1246,7 @@ function maybe_build_initialization_problem(
Vector{Equation}(initialization_eqs),
use_scc, time_dependent_init,
ReconstructInitializeprob(
sys, initializeprob.f.sys; u0_constructor,
sys, initsys; u0_constructor,
p_constructor, eval_expression, eval_module, is_steadystateprob, kwargs...
),
get_initial_unknowns, SetInitialUnknowns(sys), missing_guess_value
Expand Down Expand Up @@ -1274,7 +1274,7 @@ function maybe_build_initialization_problem(
initializeprobpmap = nothing
else
initializeprobpmap = construct_initializeprobpmap(
sys, initializeprob.f.sys; p_constructor, eval_expression, eval_module, kwargs...
sys, initsys; p_constructor, eval_expression, eval_module, kwargs...
)
end

Expand Down Expand Up @@ -1302,7 +1302,6 @@ function maybe_build_initialization_problem(
end
end
if implicit_dae
initsys = initializeprob.f.sys
for v in unknowns(sys)
v = Differential(get_iv(sys))(v)
ttv = default_toterm(v)
Expand All @@ -1328,10 +1327,10 @@ function maybe_build_initialization_problem(
end
missingvars = collect(missingvars)

for (i, v) in enumerate(unknowns(initializeprob.f.sys))
for (i, v) in enumerate(unknowns(initsys))
write_possibly_indexed_array!(temp_op, v, SConst(_u0[i]), COMMON_NOTHING)
end
add_observed!(initializeprob.f.sys, temp_op)
add_observed!(initsys, temp_op)
left_merge!(temp_op, ModelingToolkitBase.guesses(sys))
subber = Symbolics.FixpointSubstituter{true}(AADSubWrapper(temp_op))
for p in missingvars
Expand Down Expand Up @@ -1504,13 +1503,28 @@ $PROBLEM_INTERNAL_KWARGS

All other keyword arguments are passed as-is to `constructor`.
"""
function process_SciMLProblem(
constructor, sys::AbstractSystem, op;
Base.@nospecializeinfer function process_SciMLProblem(
::Type{constructor}, sys::AbstractSystem, @nospecialize(op);
u0_eltype = nothing, u0_constructor = identity, p_constructor = identity,
symbolic_u0 = false, kwargs...
) where {constructor}
u0Type = pType = typeof(op)
op = operating_point_preprocess(sys, op)
floatT = calculate_float_type(op, u0Type)
floatT = something(u0_eltype, floatT)
u0_constructor = get_u0_constructor(u0_constructor, u0Type, floatT, symbolic_u0)
p_constructor = get_p_constructor(p_constructor, pType, floatT)

__process_SciMLProblem(constructor, sys, op, floatT, u0Type; u0_constructor, p_constructor, symbolic_u0, kwargs...)
end

function __process_SciMLProblem(
::Type{constructor}, sys::AbstractSystem, op::AnyDict, ::Type{floatT}, ::Type{u0Type};
build_initializeprob = supports_initialization(sys),
implicit_dae = false, t = nothing, guesses = AnyDict(),
warn_initialize_determined = true, initialization_eqs = [],
eval_expression = false, eval_module = @__MODULE__, fully_determined = nothing,
check_initialization_units = false, u0_eltype = nothing, tofloat = true,
check_initialization_units = false, tofloat = true,
u0_constructor = identity, p_constructor = identity,
check_length = true, symbolic_u0 = false, warn_cyclic_dependency = false,
circular_dependency_max_cycle_length = length(all_symbols(sys)),
Expand All @@ -1519,25 +1533,19 @@ function process_SciMLProblem(
algebraic_only = false, missing_guess_value = default_missing_guess_value(),
allow_incomplete = false, is_initializeprob = false, is_steadystateprob = false,
return_operating_point = false, kwargs...
)
) where {constructor, floatT, u0Type}
dvs = unknowns(sys)
ps = parameters(sys; initial_parameters = true)
iv = has_iv(sys) ? get_iv(sys) : nothing
eqs = equations(sys)

check_array_equations_unknowns(eqs, dvs)

u0Type = pType = typeof(op)

op = operating_point_preprocess(sys, op)
floatT = calculate_float_type(op, u0Type)
u0_eltype = something(u0_eltype, floatT)

op = build_operating_point(sys, op; fast_path = true)

check_inputmap_keys(sys, op)

op = getmetadata(sys, ProblemConstructionHook, identity)(op)
op = getmetadata(sys, ProblemConstructionHook, identity)(op)::SymmapT

kwargs = NamedTuple(kwargs)

Expand All @@ -1552,18 +1560,15 @@ function process_SciMLProblem(
add_observed_equations!(op, obs, bindings(sys))
end

u0_constructor = get_u0_constructor(u0_constructor, u0Type, u0_eltype, symbolic_u0)
p_constructor = get_p_constructor(p_constructor, pType, floatT)

if build_initializeprob
kws = maybe_build_initialization_problem(
sys, constructor <: SciMLBase.AbstractSciMLFunction{true},
op, t, guesses; initsys_mtkcompile_kwargs,
sys, Val{constructor <: SciMLBase.AbstractSciMLFunction{true}}(),
op, t, guesses, floatT; initsys_mtkcompile_kwargs,
warn_initialize_determined, initialization_eqs,
eval_expression, eval_module, fully_determined,
warn_cyclic_dependency, check_units = check_initialization_units,
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
algebraic_only, allow_incomplete, u0_constructor, p_constructor, floatT,
algebraic_only, allow_incomplete, u0_constructor, p_constructor,
time_dependent_init, missing_guess_value, is_steadystateprob, implicit_dae,
kwargs...
)
Expand Down Expand Up @@ -1606,13 +1611,13 @@ function process_SciMLProblem(

if is_initializeprob
u0 = varmap_to_vars(
op, dvs; buffer_eltype = u0_eltype, container_type = u0Type,
op, dvs; buffer_eltype = floatT, container_type = u0Type,
allow_symbolic = symbolic_u0, is_initializeprob, substitution_limit,
missing_values = missing_guess_value
)
else
u0 = varmap_to_vars(
op, dvs; buffer_eltype = u0_eltype, container_type = u0Type,
op, dvs; buffer_eltype = floatT, container_type = u0Type,
allow_symbolic = symbolic_u0, is_initializeprob, substitution_limit
)
end
Expand All @@ -1638,13 +1643,9 @@ function process_SciMLProblem(
end

if is_split(sys)
# `pType` is usually `Dict` when the user passes key-value pairs.
if !(pType <: AbstractArray)
pType = Array
end
p = MTKParameters(sys, op; floatT = floatT, p_constructor, fast_path = true)
else
p = p_constructor(varmap_to_vars(op, ps; tofloat, container_type = pType))
p = p_constructor(varmap_to_vars(op, ps; tofloat, container_type = u0Type))
end

if implicit_dae
Expand Down Expand Up @@ -1672,7 +1673,7 @@ function process_SciMLProblem(
end

f = constructor(
sys; u0 = u0, p = p,
sys; u0 = u0, p = p, t = t,
eval_expression = eval_expression,
eval_module = eval_module,
kwargs...
Expand Down
Loading