Bit-batch phasing_hmm helpers: process 8 states per Hvar byte read for 10-15% perf improvement#292
Open
tfenne wants to merge 1 commit into
Open
Bit-batch phasing_hmm helpers: process 8 states per Hvar byte read for 10-15% perf improvement#292tfenne wants to merge 1 commit into
tfenne wants to merge 1 commit into
Conversation
Seven of the inline helper functions in phasing_hmm.h have a tight per-state
loop that calls C->Hvar.get(curr_rel_locus, k) once per iteration. Hvar.get
recomputes the byte address ((row * (n_cols/8)) + (col >> 3)) every call;
the compiler can't hoist any of that arithmetic because col is the loop
variable.
The imputation_hmm kernels already use Hvar.getByte to read 8 states at
once and decode lane-wise via _mm256_sllv_epi32 + _mm256_blendv_ps. We
can't apply that exact lane-wise pattern to phasing_hmm (here each state
produces a full __m256, not a single value), but we can still hoist the
byte read out of the inner state loop:
for (k = 0; k < n_states_full; k += 8, i += 8 * HAP_NUMBER) {
const unsigned char byte = C->Hvar.getByte(curr_rel_locus, k);
for (b = 0; b < 8; ++b) {
const bool ah = (byte >> (7 - b)) & 1;
// ... same per-state work as before ...
}
}
This replaces 8 byte-address computes + 8 byte loads + 8 shift/mask
operations per 8 states with 1 of each, per 8 states. The per-state vector
work (load prob[i], FMA, mul, add to _sum, store prob[i]) is unchanged,
so the output bit pattern of every floating-point computation is preserved.
Trailing states when n_states is not a multiple of 8 fall through to the
original Hvar.get path so the existing semantics for the partial tail are
untouched.
Applied to: INIT_PEAK_HET, INIT_PEAK_HOM, RUN_PEAK_HET, RUN_PEAK_HOM,
COLLAPSE_PEAK_HET, COLLAPSE_PEAK_HOM, IMPUTE_FLAT_HET. The three FLAT_HET
helpers without per-state Hvar lookups (INIT/RUN/COLLAPSE_FLAT_HET) are
untouched.
Wall-time on the same chrA1 chunk used in the previous PR (1.9x diploid,
1984 ref haps, --Kpbwt 2000, default 5+15 burn/main):
* x86 (Granite Rapids): 84s -> 64s (~24% on top of the previous PR)
* arm64 (Apple Silicon): 66s -> 51s (~23% on top of the previous PR)
Cumulative wall-time vs master:
* x86: 129s -> 64s (~50%)
* arm64: 140s -> 51s (~64%)
Output is byte-identical (bcftools view -H | md5sum) across 5 chunks on
both architectures.
tfenne
added a commit
to tfenne/GLIMPSE
that referenced
this pull request
May 8, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The seven inline helper functions in
phase/src/models/phasing_hmm.h(INIT_PEAK_HET,INIT_PEAK_HOM,RUN_PEAK_HET,RUN_PEAK_HOM,COLLAPSE_PEAK_HET,COLLAPSE_PEAK_HOM,IMPUTE_FLAT_HET) all share a tight per-state inner loop that callsC->Hvar.get(curr_rel_locus, k)once per iteration. EachHvar.getrecomputes the byte address((unsigned long)row) * (n_cols>>3) + (col>>3)from scratch and the compiler can't hoist it becausecolis the loop variable. Theimputation_hmmkernels already avoid this by reading a whole byte viagetByteand processing 8 states per byte. This PR applies the same loop transform to thephasing_hmmhelpers.Output is byte-identical to master. ~10-15% wall-time reduction on its own vs.
master, ~22-24% additional on top of #291 (independent change, gains compose).This PR does not touch any HMM math, any floating-point operation, any vector width, any reduction order, or any random-number draw. The per-state vector work (load
prob[i], FMA, mul, add to_sum, storeprob[i]) is unchanged. The only thing that moves is when the Hvar byte is read: once per 8 states, instead of once per state.The change
Each affected helper goes from:
to:
This replaces 8 byte-address computations + 8 single-byte loads + 8 shift/mask sequences per 8 states with one of each. The trailing fall-through preserves the original behavior when
n_statesis not a multiple of 8.The three
FLAT_HEThelpers (INIT_FLAT_HET,RUN_FLAT_HET,COLLAPSE_FLAT_HET) do no per-state Hvar lookup and are untouched.Performance
Single chrA1 chunk, 1.9x diploid, Nrh=1984,
--Kpbwt 2000, end-to-end:Stacked on top of #291 (
compactSelectionoptimization), gains compose because the two PRs hit different code paths:Tutorial step5 (16 chr22 chunks, NA12878 1x, 1000GP-no-NA12878), arm64:
All 16 tutorial chunks PASS
bcftools view -H | md5sumbyte-identity against master.Why it helps
phasing_hmmcalls these helpers once per common-het / common-hom variant and per iteration. With ~1984 states × ~6 hot helpers × ~21 iterations × thousands of variants per chunk, the hoisted byte-address arithmetic alone is significant. The compiler can't see the access pattern across loop iterations becauseHvar.getreturns through the bitmatrix abstraction; rewriting the loop to read a whole byte makes the structure explicit and gives the compiler a clean, predictable address pattern.The transformation is a 1:1 mirror of what
imputation_hmm::{forward,backward}already does — the main HMM kernels readgetByteand decode lane-wise via_mm256_sllv_epi32+_mm256_blendv_ps. We can't apply that exact lane-wise pattern inphasing_hmm(each state here produces a full__m256, not a single value), but hoisting the byte read still captures most of the benefit.Notes
conditioning_set.{h,cpp},bitmatrix.h, andphasing_hmm.cpp; this one touches onlyphasing_hmm.h). Either can land first.