Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions lib/ModelingToolkitBase/src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,41 @@ end
safe_float(x) = x
safe_float(x::AbstractArray) = isempty(x) ? x : float(x)

"""
PromoteToTunableEltype(observed)

Wraps an `initializeprob` observed function so its output array is promoted to an
eltype compatible with the current tunable parameters. Addresses the case where
the observed function is generated from fully constant RHS (e.g. `initialization_eqs
= [s ~ 0]`): the resulting `create_array(Array, nothing, …, 0, 0)` would otherwise
produce `Vector{Int64}`, which — when downstream `remake` reinstalls it as `u0` —
silently defeats ForwardDiff/Tracker/Measurements promotion of `u0`.

Replaces the previous `safe_float` layer by subsuming it: `promote_type(Int, Float64)
== Float64`, so plain problems still get `Vector{Float64}`; `promote_type(Int,
ForwardDiff.Dual)` yields the Dual type.
"""
struct PromoteToTunableEltype{F}
observed::F
end

function (p::PromoteToTunableEltype)(nlsol)
raw = p.observed(nlsol)
raw isa AbstractArray || return raw
isempty(raw) && return raw
T = promote_type(eltype(raw), _tunable_eltype(parameter_values(nlsol)), Float64)
T === eltype(raw) ? raw : convert(AbstractArray{T}, raw)
end

_tunable_eltype(p::MTKParameters) = isempty(p.tunable) ? Bool : eltype(p.tunable)
function _tunable_eltype(p)
if SciMLStructures.isscimlstructure(p)
tun = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]
return isempty(tun) ? Bool : eltype(tun)
end
return Bool
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -1258,8 +1293,8 @@ function maybe_build_initialization_problem(
if isempty(solved_unknowns)
initializeprobmap = Returns(nothing)
else
initializeprobmap = u0_constructor ∘ safe_float ∘
getu(initializeprob, solved_unknowns)
initializeprobmap = u0_constructor ∘ PromoteToTunableEltype(
getu(initializeprob, solved_unknowns))
end
else
initializeprobmap = nothing
Expand Down
32 changes: 31 additions & 1 deletion lib/ModelingToolkitBase/test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ end

@testset "No initialization for variables" begin
@variables x = 1.0
@parameters p = 10.0
@parameters p = 10.0

eqs = [
0 ~ x^2 + 2p * x + 3p
Expand Down Expand Up @@ -2013,3 +2013,33 @@ end
@test !(prob.f.initialization_data.initializeprob isa SCCNonlinearProblem)
end
end

@testset "Issue #4457" begin
@parameters m=1.5 d=9.0
@variables s(t) v(t)

eqs = [
D(s) ~ v
m * D(v) ~ 1 - d * v
]

sys = mtkcompile(System(eqs, t;
name = :model,
initialization_eqs = [s ~ 0, v ~ 0],
))

prob = ODEProblem{true, FullSpecialize}(sys, [], (0.0, 200.0))
sol = solve(prob, Tsit5(); saveat = 0.1)
@test SciMLBase.successful_retcode(sol)

setter = setp_oop(prob, [sys.m, sys.d])

function loss(x)
p = setter(prob, x)
newprob = remake(prob; p)
newsol = solve(newprob, Tsit5(); saveat = 0.1)
sum(abs2, newsol[sys.s])
end

ForwardDiff.gradient(loss, [3.0, 20.0])
end
Loading