As described in the title, and inspired by #485; see, for example, the following snippet:
|
"""
    step(lf::DefaultLeapfrog, h::Hamiltonian, z::PhasePoint, n_steps::Int=1;
         fwd::Bool=n_steps > 0, full_trajectory::Val=Val(false))

Advance the phase point `z` by `abs(n_steps)` leapfrog steps under the Hamiltonian `h`.

Each step consists of four stages:
1. A half step of the momentum, using the cached gradient.
2. A full step of the position, along `∂H∂r`.
3. A half step of the momentum, using the freshly computed `∂H∂θ`.
4. Construction of a new phase point.

`temper(lf, r, ...)` is applied to the momentum before the first half step and after the
second one. What it does depends on `lf`'s own method; it is not shown here.

# Arguments
- `lf`: the integrator; its `step_size(lf)` is the step size used.
- `h`: the Hamiltonian whose `∂H∂r` / `∂H∂θ` drive the update.
- `z`: the starting phase point. Its cached `ℓπ` (`value`, `gradient`) is used for the
  first momentum half step instead of being recomputed.
- `n_steps`: the number of steps. Only `abs(n_steps)` is used for the loop count; the sign
  only influences the default of `fwd`.

# Keywords
- `fwd`: when `false`, the step size is negated, so the integration runs backward in time.
  It defaults to `n_steps > 0`.
- `full_trajectory`: with `Val(true)`, return a `Vector{P}` holding every intermediate phase
  point. Otherwise only the final phase point is returned.

# Returns
- The final phase point, or the vector of visited phase points when `full_trajectory`
  is `Val(true)`.
- If a non-finite phase point is produced, integration stops immediately. The returned
  value is then that phase point, or a trajectory truncated to end with it.
"""
function step(
    lf::DefaultLeapfrog{FT,T},
    h::Hamiltonian,
    z::P,
    n_steps::Int=1;
    fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0
    full_trajectory::Val{FullTraj}=Val(false),
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT},P<:PhasePoint,FullTraj}
    n_steps = abs(n_steps) # to support `n_steps < 0` cases

    # Integrating backward in time is the same update with a negated step size.
    ϵ = fwd ? step_size(lf) : -step_size(lf)
    # Adjoint: a no-op for a scalar step size. A vector step size becomes a row so that it
    # broadcasts across columns (presumably one step size per chain; TODO confirm).
    ϵ = ϵ'

    # `FullTraj` is a type parameter, so this branch is resolved at compile time.
    res = FullTraj ? Vector{P}(undef, n_steps) : nothing

    (; θ, r) = z
    # Reuse the log-density value/gradient cached in `z` rather than recomputing them.
    (; value, gradient) = z.ℓπ
    for i in 1:n_steps
        # Tempering
        r = temper(lf, r, (i=i, is_half=true), n_steps)
        # Take a half leapfrog step for momentum variable
        r = r - ϵ / 2 .* gradient
        # Take a full leapfrog step for position variable
        ∇r = ∂H∂r(h, r)
        θ = θ + ϵ .* ∇r
        # Take a half leapfrog step for momentum variable; the value/gradient computed
        # here are also carried into the next iteration's first half step.
        (; value, gradient) = ∂H∂θ(h, θ)
        r = r - ϵ / 2 .* gradient
        # Tempering
        r = temper(lf, r, (i=i, is_half=false), n_steps)
        # Create a new phase point by caching the logdensity and gradient
        z = phasepoint(h, θ, r; ℓπ=DualValue(value, gradient))
        # Update result
        if !isnothing(res)
            res[i] = z
        end
        if !isfinite(z)
            # Drop the still-`undef` tail so the trajectory ends at the non-finite point.
            if !isnothing(res)
                resize!(res, i)
            end
            break
        end
    end
    return if FullTraj
        res
    else
        z
    end
end
As described in the title, and inspired by #485.
Source of the snippet above: `AdvancedHMC.jl/src/integrator.jl`, lines 216–265 at commit 6bc0c74.