Skip to content

Add Qwen3.5-4B math recipes and make Qwen3 dense recipes run on the transformers-5.8 / torch-2.11 stack#21

Open
iamziyuzhao wants to merge 11 commits into
Infini-AI-Lab:devfrom
iamziyuzhao:add-qwen3.5-recipes-and-qwen3-dense-fix
Open

Add Qwen3.5-4B math recipes and make Qwen3 dense recipes run on the transformers-5.8 / torch-2.11 stack#21
iamziyuzhao wants to merge 11 commits into
Infini-AI-Lab:devfrom
iamziyuzhao:add-qwen3.5-recipes-and-qwen3-dense-fix

Conversation

@iamziyuzhao

@iamziyuzhao iamziyuzhao commented Jun 22, 2026

Copy link
Copy Markdown

Summary

This PR makes AstraFlow run both the Qwen3 dense math recipes and new
Qwen3.5 math recipes on the current transformers-5 / torch-2.11 stack.

It supersedes #11 (which added the Qwen3.5 recipe but left the existing dense
Qwen3 math recipes broken on transformers ≥ 5). Targets dev.

What changed

1. Make Qwen3 dense recipes run on transformers ≥ 5 (fsdp_engine.py, recipes, cli_args.py)

  • The packed/varlen training forward now passes attention_mask=None for all
    archs. The old dict(full_attention=None, sliding_attention=None) form is a
    transformers-4.x relic: on transformers ≥ 5 a dense model (qwen3/qwen2) treats
    that dict as a precomputed mask, skips creation, and crashes. None is
    correct — masking is driven by cu_seqlens + position_ids. (Verified more
    correct than the old form even where it didn't crash: under SDPA the old dict
    path silently lost block-diagonal separation.)
  • Dense recipes (qwen3-1.7b-m2po-2gpus-{full,delta}, qwen3-8b-m2po-{full,delta})
    and the cli_args.py default now set attn_impl: kernels-community/flash-attn2
    a prebuilt FlashAttention-2 kernel pulled from the HF kernels hub (build
    variant ABI-matched to torch 2.11+cu130). The literal flash_attention_2 is
    avoided because it loads the local flash-attn wheel, which is ABI-broken vs
    torch 2.11+cu130; sdpa/eager remain valid choices. (FA2-kernels and
    sdpa were confirmed eval-equivalent within single-sample noise; FA2 is the
    faster default.)

2. Add Qwen3.5-4B math recipes + enablement

  • New examples/math/qwen3.5-4b-m2po-{full,delta}/ recipes (M2PO, ctx 8k,
    DeepScaleR, math_verify), with a README documenting the validated stack.
    Qwen3.5 is a hybrid Gated-DeltaNet + attention checkpoint; the trainer uses
    attn_impl: sdpa (GDN linear-attn via fla kernels, full-attn via sdpa) and
    inference uses attention_backend: flashinfer.
  • model.py: register qwen3_5 in VALID_VISION_MODELS (the checkpoint ships
    as Qwen3_5ForConditionalGeneration, loaded via the ImageTextToText path;
    trained text-only).

Validation

Qwen3-1.7B dense, transformers 5.8.1: full + delta trained 300 steps, eval
@0/100/200/300. math500 avg@k: full 73.8 → 79.9, delta 72.4 → 80.2; overall
avg@k: full 31.3 → 38.8, delta 31.1 → 38.3. importance_weight ≈ 1.0 every step
(packed forward correct); zero crashes; resumed cleanly from an external kill.

Qwen3.5-4B (from #11): trains end-to-end and eval rises steadily.

Qwen3.5-4B (overall) step 0 step 80 Δ
avg@k 47.8% 57.4% +9.6
pass@k 56.5% 67.4% +10.9

Qwen3.5-4B re-validated on the pinned 0.5.13.post1 release: both full
and delta re-run end-to-end on sglang==0.5.13.post1 — training completes,
full (shard_copy) and delta (~7× compressed) weight transfer both work, and
eval holds at baseline (~49–51% overall avg@k over a short run). No regression
vs the predecessor git build the table above came from.

Validated environment (Qwen3.5)

torch 2.11.0+cu130, transformers 5.8.1, kernels 0.14.1, SGLang 0.5.13.post1
(published release with qwen3_5 support) served via attention_backend: flashinfer, flash-linear-attention 0.5.0, flashinfer-python 0.6.12,
attn_impl=sdpa.

pyproject.toml pins the full validated stack: transformers==5.8.1 (with
kernels>=0.14,<0.15), torch==2.11.0, flash-linear-attention==0.5.0 (the
fast Qwen3.5 GDN kernels; optional, with a pure-torch fallback), and
sglang==0.5.13.post1 — the published release that ships qwen3_5 (the
older 0.5.12.post1 predated it), which pulls flashinfer 0.6.12 in automatically.
uv pip install -e ".[sglang]" resolves cleanly to exactly the validated
versions.

haizhongzheng and others added 2 commits June 22, 2026 10:06
New recipes examples/math/qwen3.5-4b-m2po-{full,delta} for training Qwen3.5
(dense, hybrid Gated-DeltaNet text backbone; model_type qwen3_5) with M2PO on
AstraFlow, mirroring the existing qwen3-8b-m2po recipe structure. Trained
text-only for math (the checkpoint ships as Qwen3_5ForConditionalGeneration).

Verified end-to-end on NVIDIA L40 (Ada, sm_89): a full run trained 86+ steps
with no crash and steadily rising eval — overall avg@k 47.8% -> 57.4% (+9.6) and
pass@k 56.5% -> 67.4% over the first 80 steps (AIME24/AIME25/AMC/Minerva/MATH500,
eval every 10 steps).

Minimal framework changes for Qwen3.5 / transformers>=5 compatibility:
- model.py: register qwen3_5 + is_qwen3_5_model()
- fsdp_engine.py: pass attention_mask=None for qwen3_5 (transformers>=5
  create_causal_mask calls .ndim; the old dict form raised AttributeError)
- fsdp/__init__.py: normalize _no_split_modules set->list (qwen3_5 exposes a set)
- rlvr.py: unwrap BatchEncoding from apply_chat_template (transformers>=5)

Recipes use the standard packed training forward. attention_backend=flashinfer
(fa3 dispatches a Hopper-only kernel that fails on Ada/L40 for the GDN path;
flashinfer + triton both verified); max_running_requests=32, mem_fraction_static=0.7
on inference and FSDP dp=4 + max_tokens_per_mb=8192 on the trainer to fit 44GB L40.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The version bump (transformers 5.8.1, torch 2.11+cu130, sglang dev) broke the
existing plain-Qwen3 math recipes (qwen3-1.7b, qwen3-8b). Two fixes:

1. fsdp_engine.py: always pass attention_mask=None for the packed/varlen forward.
   The old dict(full_attention=None, sliding_attention=None) form is a
   transformers-4.x relic; on transformers>=5 a dense model (qwen3/qwen2) treats
   that dict as a *precomputed* mask, skips creation, and crashes ('dict' object
   has no attribute 'ndim'). None is correct for all archs (dense, moe, vl,
   qwen3.5/GDN) — masking is driven by cu_seqlens + position_ids. Subsumes the
   prior qwen3_5/moe/vl special-case (drop now-unused imports).

2. recipes + cli_args: set attn_impl: sdpa in actor+ref for qwen3-1.7b and
   qwen3-8b recipes (flash_attn is ABI-broken vs torch 2.11+cu130 -> import crash;
   default was flash_attention_2). Expand attn_impl choices to sdpa/eager.

Verified end-to-end: qwen3-1.7b full + delta train cleanly to step 100 on the
bumped stack; math500 avg@k rises 73.0->77.0 (full) / 72.2->78.8 (delta);
importance_weight=1.0000 (packed forward correct); zero crashes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- Remove the unused is_qwen3_5_model() helper (its only caller, the mask-branch
  special-case, went away when the packed-forward mask was made unconditional).
- Note in model.py why qwen3_5 is in VALID_VISION_MODELS (the checkpoint ships
  as Qwen3_5ForConditionalGeneration -> loaded via the ImageTextToText path).
- Simplify the packed-forward attention_mask comment in fsdp_engine.py.

No behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@iamziyuzhao iamziyuzhao force-pushed the add-qwen3.5-recipes-and-qwen3-dense-fix branch from b294f1e to 86c8ffb Compare June 22, 2026 14:56
Document the validated runtime stack (transformers 5.8.1, kernels 0.14.1, SGLang
dev with qwen3_5 + TritonGDNKernel, fla 0.5.0, flashinfer 0.6.11.post1, torch
2.11.0+cu130, attn_impl sdpa), GPU layout, run commands, and the validated eval
(overall avg@k 47.8 -> 57.4, pass@k 56.5 -> 67.4 over 80 steps).

The default pyproject pins load qwen3_5, but the validated stack is installed out
of band; a pin bump is deferred to a separate, tested PR (the validated SGLang is
a dev build, not a clean release).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@iamziyuzhao iamziyuzhao force-pushed the add-qwen3.5-recipes-and-qwen3-dense-fix branch from 86c8ffb to 56ea705 Compare June 22, 2026 15:05
haizhongzheng and others added 7 commits June 22, 2026 11:22
Bump the pyproject transformers pin (core dep + uv override) from 5.6.1 to the
5.8.1 that the Qwen3.5 and Qwen3-dense math recipes were validated on, and move
the coupled kernels constraint from <0.13 to >=0.14,<0.15 (what transformers
5.8.1 was validated against). torch is already 2.11.0; flashinfer comes in
transitively via sglang. sglang stays pinned at the published 0.5.12.post1 — the
Qwen3.5 inference path was validated on an sglang dev build that ships qwen3_5,
as noted in the recipe README.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
….8.1)

The published sglang 0.5.12.post1 predates qwen3_5 and pins transformers 5.6.x,
so it is incompatible with the transformers==5.8.1 the recipes require. Pin
sglang to the validated main-branch build (sgl-project/sglang @ 373cadc9): it
ships qwen3_5 + TritonGDNKernel and itself requires transformers==5.8.1,
flashinfer 0.6.11.post1, torch 2.11.0, kernels<0.15 -- all matching the
validated env. Update the [tool.uv] comments to match.

Verified: every pin/override matches the working astraflow35 env (transformers
5.8.1, kernels 0.14.1, outlines-core 0.1.26, torch 2.11.0, sglang @ g373cadc92).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Comment-only alignment of the Qwen3.5 recipes + touched code to the repo's
conventions (qwen3-8b recipes / surrounding code as the baseline):
- recipe experiment.yaml/raas.yaml: restore the "# -- ... --" section banners
  and the auto-derives / adaptive-availability rationale comments the dense
  recipes use; restyle the attn_impl comment to sentence-case prose.
- model.py: capitalize the qwen3_5 registry comment to match the file.
- fsdp_engine.py: reflow the attention_mask comment so dict(...) is not split
  mid-token, matching the file's other multi-line comments.
- pyproject.toml: reflow the sglang-extra comment to the column band.
- README: tag the GPU-layout code fence (text).
No functional content changed (verified: all yaml/sh/py/toml parse; only comment
lines were added/removed).

Also refresh the recipe README install note, which still described SGLang as
pinned at 0.5.12.post1 before the git-build pin landed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… 3.6.1

Two pyproject fixes so a clean install reproduces the validated astraflow35 env:
- Add flash-linear-attention==0.5.0 to the sglang extra: the fast Gated-DeltaNet
  kernels Qwen3.5 GDN training used in the validated runs (optional --
  transformers falls back to a slower pure-torch path when absent). Pulls
  fla-core==0.5.0.
- networkx==3.3 -> 3.6.1 to match the validated env (the only pinned version that
  differed from astraflow35).

Verified: uv pip install --dry-run -e ".[sglang]" resolves cleanly (299
packages, exit 0) to the validated versions (flash-linear-attention 0.5.0,
fla-core 0.5.0, networkx 3.6.1, transformers 5.8.1, sglang @ 373cadc9, ...).
All other pinned ML versions already matched astraflow35; the ~60 loose utility
deps are left flexible per the repo's convention.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Switch the dense Qwen3 math recipes (qwen3-1.7b-m2po-2gpus-{full,delta},
qwen3-8b-m2po-{full,delta}) and the cli_args attn_impl default to
kernels-community/flash-attn2 -- a prebuilt, ABI-matched FlashAttention-2
kernel from the HF kernels hub (cached on first use, no source build).

The literal flash_attention_2 loads the local flash-attn wheel, which is
ABI-broken on torch 2.11+cu130 (undefined symbol); is_flash_attn_2_available()
only checks package metadata so it reports available and then crashes on import.
The kernels-hub variant is ABI-matched and is the upstream-faithful FA2 the
recipes were tuned with. A paired step-25 A/B/C on qwen3-1.7b-m2po-2gpus-full
gave overall avg@k FA2 >= sdpa >= sdpa+recompute_logprob, all within eval noise;
FA2 varlen also derives the packed block-diagonal mask from cu_seqlens, avoiding
sdpa's reliance on position_ids resets. Qwen3.5 recipes stay on sdpa.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ild)

Switch the [sglang] extra from the git build (373cadc9) to the published
sglang==0.5.13.post1 -- the release the recipes were recently validated against,
which ships qwen3_5 support (sglang/srt/models/qwen3_5.py) so it covers both the
Qwen3.5 and Qwen3-dense recipes. Update the Qwen3.5 READMEs' validated-stack
tables to match (sglang 0.5.13.post1, flashinfer 0.6.12). The flash-attn-4
pre-release / transformers 5.8.1 / kernels 0.14.x [tool.uv] overrides are
unchanged (0.5.13.post1 still pulls flash-attn-4 4.0.0b15).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Both full and delta variants re-run end-to-end on the published-release pin
(sglang==0.5.13.post1): training completes, full + delta weight transfer both
work, eval holds at baseline (~49-51% overall avg@k) -- no regression vs the
predecessor git build the step0->step80 table was produced on.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants