Bit-batch phasing_hmm helpers: process 8 states per Hvar byte read for 10-15% perf improvement #292

Open
tfenne wants to merge 1 commit into odelaneau:master from tfenne:tf_phasing_bitbatch

tfenne (Contributor) commented May 5, 2026

Summary

The seven inline helper functions in phase/src/models/phasing_hmm.h (INIT_PEAK_HET, INIT_PEAK_HOM, RUN_PEAK_HET, RUN_PEAK_HOM, COLLAPSE_PEAK_HET, COLLAPSE_PEAK_HOM, IMPUTE_FLAT_HET) all share a tight per-state inner loop that calls C->Hvar.get(curr_rel_locus, k) once per iteration. Each Hvar.get recomputes the byte address ((unsigned long)row) * (n_cols>>3) + (col>>3) from scratch and the compiler can't hoist it because col is the loop variable. The imputation_hmm kernels already avoid this by reading a whole byte via getByte and processing 8 states per byte. This PR applies the same loop transform to the phasing_hmm helpers.
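The difference between the two accessors can be sketched as follows. This is a minimal stand-in, not the project's bitmatrix: the struct name BitMatrixSketch, its members, and the set helper are assumptions for illustration, and n_cols is assumed to be a multiple of 8. Only the address formula and the MSB-first bit order are taken from the description above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the project's bitmatrix (names and layout assumed).
struct BitMatrixSketch {
    std::size_t n_rows, n_cols;        // n_cols assumed to be a multiple of 8
    std::vector<unsigned char> bytes;  // row-major, 8 columns packed per byte

    BitMatrixSketch(std::size_t rows, std::size_t cols)
        : n_rows(rows), n_cols(cols), bytes(rows * (cols >> 3), 0) {}

    // Set one bit; column 0 of each byte is the most significant bit.
    void set(std::size_t row, std::size_t col, bool v) {
        unsigned char& byte = bytes[row * (n_cols >> 3) + (col >> 3)];
        const unsigned char mask = static_cast<unsigned char>(0x80u >> (col & 7));
        byte = static_cast<unsigned char>(v ? (byte | mask) : (byte & ~mask));
    }

    // Per-bit read: the byte address is recomputed on every call.
    bool get(std::size_t row, std::size_t col) const {
        return (bytes[row * (n_cols >> 3) + (col >> 3)] >> (7 - (col & 7))) & 1;
    }

    // Whole-byte read: one address computation covers columns col..col+7
    // (col is expected to be a multiple of 8).
    unsigned char getByte(std::size_t row, std::size_t col) const {
        return bytes[row * (n_cols >> 3) + (col >> 3)];
    }
};
```

Calling get for eight consecutive columns repeats the same multiply, shift, and load eight times; a single getByte returns the same eight bits, which the caller can then decode with (byte >> (7 - b)) & 1.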

Output is byte-identical to master. On the chrA1 benchmark, wall time drops by ~10-15% vs. master with this PR alone, and by a further ~22-24% when stacked on #291 (an independent change, so the gains compose).

This PR does not touch any HMM math, any floating-point operation, any vector width, any reduction order, or any random-number draw. The per-state vector work (load prob[i], FMA, mul, add to _sum, store prob[i]) is unchanged. The only thing that moves is when the Hvar byte is read: once per 8 states, instead of once per state.

The change

Each affected helper goes from:

for (int k = 0, i = 0 ; k != C->n_states ; ++k, i += HAP_NUMBER) {
    const bool ah = C->Hvar.get(curr_rel_locus, k);
    /* ... per-state vector work ... */
}

to:

const unsigned int n_states_full = (C->n_states / 8) * 8;
int k = 0, i = 0;
for ( ; k < n_states_full ; k += 8, i += 8 * HAP_NUMBER) {
    const unsigned char byte = C->Hvar.getByte(curr_rel_locus, k);
    for (int b = 0 ; b < 8 ; ++b) {
        const bool ah = (byte >> (7 - b)) & 1;
        /* ... same per-state vector work, indexed by i + b * HAP_NUMBER ... */
    }
}
for ( ; k < C->n_states ; ++k, i += HAP_NUMBER) {
    /* original Hvar.get path, preserves semantics for trailing bits */
}

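This replaces 8 byte-address computations and 8 single-byte loads per 8 states with one of each. The per-bit shift/mask remains, but its shift amount (7 - b) now comes from a fixed-count inner loop that the compiler can unroll. The trailing fall-through preserves the original behavior when n_states is not a multiple of 8.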
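As a self-contained illustration of why the two loop shapes give the same result, including the tail, here is a scalar sketch. Everything in it is an assumption made for the example: per_state_work stands in for the real vector work, get_bit stands in for Hvar.get on a single packed row, HAP_NUMBER is taken to be 8, and the weights are arbitrary. It is not the code from phasing_hmm.h.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int HAP_NUMBER = 8;  // assumed: floats per state (one __m256 in the real code)

// Hypothetical per-state work: scale the state's probabilities by a weight
// chosen from the allele bit and accumulate them into sum.
inline void per_state_work(bool ah, float* prob, float& sum) {
    const float w = ah ? 0.5f : 0.25f;
    for (int h = 0; h < HAP_NUMBER; ++h) { prob[h] *= w; sum += prob[h]; }
}

// MSB-first single-bit read from one packed row (stand-in for Hvar.get).
inline bool get_bit(const std::vector<unsigned char>& row, int k) {
    return (row[k >> 3] >> (7 - (k & 7))) & 1;
}

// Original shape: one bit lookup per state.
float run_per_state(const std::vector<unsigned char>& row, int n_states,
                    std::vector<float>& prob) {
    float sum = 0.0f;
    for (int k = 0, i = 0; k != n_states; ++k, i += HAP_NUMBER)
        per_state_work(get_bit(row, k), &prob[i], sum);
    return sum;
}

// Bit-batched shape: one byte read per 8 states, then a per-state tail.
float run_bit_batched(const std::vector<unsigned char>& row, int n_states,
                      std::vector<float>& prob) {
    float sum = 0.0f;
    const int n_states_full = (n_states / 8) * 8;
    int k = 0, i = 0;
    for (; k < n_states_full; k += 8, i += 8 * HAP_NUMBER) {
        const unsigned char byte = row[k >> 3];  // one read covers 8 states
        for (int b = 0; b < 8; ++b)
            per_state_work((byte >> (7 - b)) & 1, &prob[i + b * HAP_NUMBER], sum);
    }
    for (; k < n_states; ++k, i += HAP_NUMBER)   // tail: fewer than 8 states left
        per_state_work(get_bit(row, k), &prob[i], sum);
    return sum;
}
```

Both functions visit the states in the same order and perform the same operations on each, so the accumulated sum and the updated prob array match exactly. With n_states = 19, the first 16 states go through the byte path and the last 3 through the tail.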

The three FLAT_HET helpers (INIT_FLAT_HET, RUN_FLAT_HET, COLLAPSE_FLAT_HET) do no per-state Hvar lookup and are untouched.

Performance

Single chrA1 chunk, 1.9x diploid, Nrh=1984, --Kpbwt 2000, end-to-end:

| Platform | Master | This PR alone (reduction vs. master) |
|---|---|---|
| x86 (Granite Rapids, native AVX2) | 129s | 110s (~15%) |
| arm64 (Apple Silicon, AVX2 → NEON via simde) | 140s | 126s (~10%) |

Stacked on top of #291 (compactSelection optimization), gains compose because the two PRs hit different code paths:

| Chunk | Master | PR291 alone | PR291 + this PR (reduction vs. master) |
|---|---|---|---|
| x86 chrA1 | 129s | 84s | 64s (~50%) |
| arm64 chrA1 | 140s | 66s | 51s (~64%) |

Tutorial step5 (16 chr22 chunks, NA12878 1x, 1000GP-no-NA12878), arm64:

| | PR291 alone | + this PR |
|---|---|---|
| total | 185s | 155s (~16% additional) |

All 16 tutorial chunks PASS bcftools view -H | md5sum byte-identity against master.

Why it helps

phasing_hmm calls these helpers once per common-het / common-hom variant in every iteration. With ~1984 states × ~6 hot helpers × ~21 iterations × thousands of variants per chunk, the per-state byte-address arithmetic is significant on its own. The compiler can't see the access pattern across loop iterations because Hvar.get returns through the bitmatrix abstraction; rewriting the loop to read a whole byte makes the structure explicit and gives the compiler a clean, predictable address pattern.
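To put a number on that product, here is a back-of-the-envelope count per variant. It takes the rough figures above at face value and assumes, as an upper bound, that every hot helper runs on every variant in every iteration; the constant names are made up for the example.

```cpp
#include <cstdint>

// Rough figures quoted in the text; all are approximations.
constexpr std::int64_t n_states  = 1984;
constexpr std::int64_t n_helpers = 6;
constexpr std::int64_t n_iters   = 21;

// Hvar reads per variant with one Hvar.get per state.
constexpr std::int64_t reads_per_state_loop = n_states * n_helpers * n_iters;

// Hvar reads per variant with one getByte per 8 states
// (1984 is a multiple of 8, so there is no tail here).
constexpr std::int64_t reads_bit_batched = (n_states / 8) * n_helpers * n_iters;

static_assert(reads_per_state_loop == 249984, "about 250k bit reads per variant");
static_assert(reads_bit_batched == 31248, "about 31k byte reads per variant");
```

Under these assumptions each variant goes from roughly 250k address computations and loads to roughly 31k, before multiplying by the thousands of variants in a chunk.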

The transformation follows what imputation_hmm::{forward,backward} already does: the main HMM kernels read getByte and decode lane-wise via _mm256_sllv_epi32 + _mm256_blendv_ps. We can't apply that exact lane-wise decode in phasing_hmm (each state here produces a full __m256, not a single value), but hoisting the byte read still captures most of the benefit.

Notes

  • Independent of #291 (Optimize compactSelection: 25-50% wall-time reduction in GLIMPSE2_phase, byte-identical results). The two PRs touch disjoint files (#291 touches conditioning_set.{h,cpp}, bitmatrix.h, and phasing_hmm.cpp; this one touches only phasing_hmm.h). Either can land first.
  • No new tests — consistent with the project's existing approach (end-to-end tutorial scripts as ground truth). The byte-identity check is the validation I relied on.
  • Builds clean on Linux/x86_64 (gcc 11, native AVX2) and macOS/arm64 (clang, simde).

Seven of the inline helper functions in phasing_hmm.h have a tight per-state
loop that calls C->Hvar.get(curr_rel_locus, k) once per iteration. Hvar.get
recomputes the byte address ((row * (n_cols/8)) + (col >> 3)) every call;
the compiler can't hoist any of that arithmetic because col is the loop
variable.

The imputation_hmm kernels already use Hvar.getByte to read 8 states at
once and decode lane-wise via _mm256_sllv_epi32 + _mm256_blendv_ps. We
can't apply that exact lane-wise pattern to phasing_hmm (here each state
produces a full __m256, not a single value), but we can still hoist the
byte read out of the inner state loop:

    for (k = 0; k < n_states_full; k += 8, i += 8 * HAP_NUMBER) {
        const unsigned char byte = C->Hvar.getByte(curr_rel_locus, k);
        for (b = 0; b < 8; ++b) {
            const bool ah = (byte >> (7 - b)) & 1;
            // ... same per-state work as before ...
        }
    }

This replaces 8 byte-address computes + 8 byte loads per 8 states with 1
of each. The per-bit shift/mask remains, but its shift amount (7 - b) now
comes from a fixed-count inner loop that the compiler can unroll. The
per-state vector work (load prob[i], FMA, mul, add to _sum, store prob[i])
is unchanged, so the output bit pattern of every floating-point
computation is preserved.

Trailing states when n_states is not a multiple of 8 fall through to the
original Hvar.get path so the existing semantics for the partial tail are
untouched.

Applied to: INIT_PEAK_HET, INIT_PEAK_HOM, RUN_PEAK_HET, RUN_PEAK_HOM,
COLLAPSE_PEAK_HET, COLLAPSE_PEAK_HOM, IMPUTE_FLAT_HET. The three FLAT_HET
helpers without per-state Hvar lookups (INIT/RUN/COLLAPSE_FLAT_HET) are
untouched.

Wall-time on the same chrA1 chunk used in the previous PR (1.9x diploid,
1984 ref haps, --Kpbwt 2000, default 5+15 burn/main):

  * x86 (Granite Rapids):  84s -> 64s  (~24% on top of the previous PR)
  * arm64 (Apple Silicon): 66s -> 51s  (~23% on top of the previous PR)

Cumulative wall-time vs master:

  * x86:   129s -> 64s  (~50%)
  * arm64: 140s -> 51s  (~64%)

Output is byte-identical (bcftools view -H | md5sum) across 5 chunks on
both architectures.
@srubinacci srubinacci self-requested a review May 6, 2026 06:12
tfenne added a commit to tfenne/GLIMPSE that referenced this pull request May 8, 2026