Add Qwen3.5-4B math recipes and make Qwen3 dense recipes run on the transformers-5.8 / torch-2.11 stack#21
Open
iamziyuzhao wants to merge 11 commits into
Conversation
New recipes examples/math/qwen3.5-4b-m2po-{full,delta} for training Qwen3.5
(dense, hybrid Gated-DeltaNet text backbone; model_type qwen3_5) with M2PO on
AstraFlow, mirroring the existing qwen3-8b-m2po recipe structure. Trained
text-only for math (the checkpoint ships as Qwen3_5ForConditionalGeneration).
Verified end-to-end on NVIDIA L40 (Ada, sm_89): a full run trained 86+ steps
with no crash and steadily rising eval — overall avg@k 47.8% -> 57.4% (+9.6) and
pass@k 56.5% -> 67.4% over the first 80 steps (AIME24/AIME25/AMC/Minerva/MATH500,
eval every 10 steps).
Minimal framework changes for Qwen3.5 / transformers>=5 compatibility:
- model.py: register qwen3_5 + is_qwen3_5_model()
- fsdp_engine.py: pass attention_mask=None for qwen3_5 (transformers>=5
create_causal_mask calls .ndim; the old dict form raised AttributeError)
- fsdp/__init__.py: normalize _no_split_modules set->list (qwen3_5 exposes a set)
- rlvr.py: unwrap BatchEncoding from apply_chat_template (transformers>=5)
Recipes use the standard packed training forward. attention_backend=flashinfer
(fa3 dispatches a Hopper-only kernel that fails on Ada/L40 for the GDN path;
flashinfer + triton both verified); max_running_requests=32, mem_fraction_static=0.7
on inference and FSDP dp=4 + max_tokens_per_mb=8192 on the trainer to fit 44GB L40.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The version bump (transformers 5.8.1, torch 2.11+cu130, sglang dev) broke the
existing plain-Qwen3 math recipes (qwen3-1.7b, qwen3-8b). Two fixes:
1. fsdp_engine.py: always pass attention_mask=None for the packed/varlen forward.
The old dict(full_attention=None, sliding_attention=None) form is a
transformers-4.x relic; on transformers>=5 a dense model (qwen3/qwen2) treats
that dict as a *precomputed* mask, skips creation, and crashes ('dict' object
has no attribute 'ndim'). None is correct for all archs (dense, moe, vl,
qwen3.5/GDN) — masking is driven by cu_seqlens + position_ids. Subsumes the
prior qwen3_5/moe/vl special-case (drop now-unused imports).
2. recipes + cli_args: set attn_impl: sdpa in actor+ref for qwen3-1.7b and
qwen3-8b recipes (flash_attn is ABI-broken vs torch 2.11+cu130 -> import crash;
default was flash_attention_2). Expand attn_impl choices to sdpa/eager.
Verified end-to-end: qwen3-1.7b full + delta train cleanly to step 100 on the
bumped stack; math500 avg@k rises 73.0->77.0 (full) / 72.2->78.8 (delta);
importance_weight=1.0000 (packed forward correct); zero crashes.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- Remove the unused is_qwen3_5_model() helper (its only caller, the mask-branch special-case, went away when the packed-forward mask was made unconditional). - Note in model.py why qwen3_5 is in VALID_VISION_MODELS (the checkpoint ships as Qwen3_5ForConditionalGeneration -> loaded via the ImageTextToText path). - Simplify the packed-forward attention_mask comment in fsdp_engine.py. No behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
b294f1e to
86c8ffb
Compare
Document the validated runtime stack (transformers 5.8.1, kernels 0.14.1, SGLang dev with qwen3_5 + TritonGDNKernel, fla 0.5.0, flashinfer 0.6.11.post1, torch 2.11.0+cu130, attn_impl sdpa), GPU layout, run commands, and the validated eval (overall avg@k 47.8 -> 57.4, pass@k 56.5 -> 67.4 over 80 steps). The default pyproject pins load qwen3_5, but the validated stack is installed out of band; a pin bump is deferred to a separate, tested PR (the validated SGLang is a dev build, not a clean release). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
86c8ffb to
56ea705
Compare
Bump the pyproject transformers pin (core dep + uv override) from 5.6.1 to the 5.8.1 that the Qwen3.5 and Qwen3-dense math recipes were validated on, and move the coupled kernels constraint from <0.13 to >=0.14,<0.15 (what transformers 5.8.1 was validated against). torch is already 2.11.0; flashinfer comes in transitively via sglang. sglang stays pinned at the published 0.5.12.post1 — the Qwen3.5 inference path was validated on an sglang dev build that ships qwen3_5, as noted in the recipe README. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
….8.1) The published sglang 0.5.12.post1 predates qwen3_5 and pins transformers 5.6.x, so it is incompatible with the transformers==5.8.1 the recipes require. Pin sglang to the validated main-branch build (sgl-project/sglang @ 373cadc9): it ships qwen3_5 + TritonGDNKernel and itself requires transformers==5.8.1, flashinfer 0.6.11.post1, torch 2.11.0, kernels<0.15 -- all matching the validated env. Update the [tool.uv] comments to match. Verified: every pin/override matches the working astraflow35 env (transformers 5.8.1, kernels 0.14.1, outlines-core 0.1.26, torch 2.11.0, sglang @ g373cadc92). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Comment-only alignment of the Qwen3.5 recipes + touched code to the repo's conventions (qwen3-8b recipes / surrounding code as the baseline): - recipe experiment.yaml/raas.yaml: restore the "# -- ... --" section banners and the auto-derives / adaptive-availability rationale comments the dense recipes use; restyle the attn_impl comment to sentence-case prose. - model.py: capitalize the qwen3_5 registry comment to match the file. - fsdp_engine.py: reflow the attention_mask comment so dict(...) is not split mid-token, matching the file's other multi-line comments. - pyproject.toml: reflow the sglang-extra comment to the column band. - README: tag the GPU-layout code fence (text). No functional content changed (verified: all yaml/sh/py/toml parse; only comment lines were added/removed). Also refresh the recipe README install note, which still described SGLang as pinned at 0.5.12.post1 before the git-build pin landed. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… 3.6.1 Two pyproject fixes so a clean install reproduces the validated astraflow35 env: - Add flash-linear-attention==0.5.0 to the sglang extra: the fast Gated-DeltaNet kernels Qwen3.5 GDN training used in the validated runs (optional -- transformers falls back to a slower pure-torch path when absent). Pulls fla-core==0.5.0. - networkx==3.3 -> 3.6.1 to match the validated env (the only pinned version that differed from astraflow35). Verified: uv pip install --dry-run -e ".[sglang]" resolves cleanly (299 packages, exit 0) to the validated versions (flash-linear-attention 0.5.0, fla-core 0.5.0, networkx 3.6.1, transformers 5.8.1, sglang @ 373cadc9, ...). All other pinned ML versions already matched astraflow35; the ~60 loose utility deps are left flexible per the repo's convention. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Switch the dense Qwen3 math recipes (qwen3-1.7b-m2po-2gpus-{full,delta},
qwen3-8b-m2po-{full,delta}) and the cli_args attn_impl default to
kernels-community/flash-attn2 -- a prebuilt, ABI-matched FlashAttention-2
kernel from the HF kernels hub (cached on first use, no source build).
The literal flash_attention_2 loads the local flash-attn wheel, which is
ABI-broken on torch 2.11+cu130 (undefined symbol); is_flash_attn_2_available()
only checks package metadata so it reports available and then crashes on import.
The kernels-hub variant is ABI-matched and is the upstream-faithful FA2 the
recipes were tuned with. A paired step-25 A/B/C on qwen3-1.7b-m2po-2gpus-full
gave overall avg@k FA2 >= sdpa >= sdpa+recompute_logprob, all within eval noise;
FA2 varlen also derives the packed block-diagonal mask from cu_seqlens, avoiding
sdpa's reliance on position_ids resets. Qwen3.5 recipes stay on sdpa.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ild) Switch the [sglang] extra from the git build (373cadc9) to the published sglang==0.5.13.post1 -- the release the recipes were recently validated against, which ships qwen3_5 support (sglang/srt/models/qwen3_5.py) so it covers both the Qwen3.5 and Qwen3-dense recipes. Update the Qwen3.5 READMEs' validated-stack tables to match (sglang 0.5.13.post1, flashinfer 0.6.12). The flash-attn-4 pre-release / transformers 5.8.1 / kernels 0.14.x [tool.uv] overrides are unchanged (0.5.13.post1 still pulls flash-attn-4 4.0.0b15). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Both full and delta variants re-run end-to-end on the published-release pin (sglang==0.5.13.post1): training completes, full + delta weight transfer both work, eval holds at baseline (~49-51% overall avg@k) -- no regression vs the predecessor git build the step0->step80 table was produced on. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR makes AstraFlow run both the Qwen3 dense math recipes and new
Qwen3.5 math recipes on the current transformers-5 / torch-2.11 stack.
It supersedes #11 (which added the Qwen3.5 recipe but left the existing dense
Qwen3 math recipes broken on transformers ≥ 5). Targets
dev.What changed
1. Make Qwen3 dense recipes run on transformers ≥ 5 (
fsdp_engine.py, recipes,cli_args.py)attention_mask=Nonefor allarchs. The old
dict(full_attention=None, sliding_attention=None)form is atransformers-4.x relic: on transformers ≥ 5 a dense model (qwen3/qwen2) treats
that dict as a precomputed mask, skips creation, and crashes.
Noneiscorrect — masking is driven by
cu_seqlens+position_ids. (Verified morecorrect than the old form even where it didn't crash: under SDPA the old dict
path silently lost block-diagonal separation.)
qwen3-1.7b-m2po-2gpus-{full,delta},qwen3-8b-m2po-{full,delta})and the
cli_args.pydefault now setattn_impl: kernels-community/flash-attn2—a prebuilt FlashAttention-2 kernel pulled from the HF kernels hub (build
variant ABI-matched to torch 2.11+cu130). The literal
flash_attention_2isavoided because it loads the local flash-attn wheel, which is ABI-broken vs
torch 2.11+cu130;
sdpa/eagerremain validchoices. (FA2-kernels andsdpawere confirmed eval-equivalent within single-sample noise; FA2 is thefaster default.)
2. Add Qwen3.5-4B math recipes + enablement
examples/math/qwen3.5-4b-m2po-{full,delta}/recipes (M2PO, ctx 8k,DeepScaleR,
math_verify), with a README documenting the validated stack.Qwen3.5 is a hybrid Gated-DeltaNet + attention checkpoint; the trainer uses
attn_impl: sdpa(GDN linear-attn viaflakernels, full-attn via sdpa) andinference uses
attention_backend: flashinfer.model.py: registerqwen3_5inVALID_VISION_MODELS(the checkpoint shipsas
Qwen3_5ForConditionalGeneration, loaded via the ImageTextToText path;trained text-only).
Validation
Qwen3-1.7B dense, transformers 5.8.1: full + delta trained 300 steps, eval
@0/100/200/300. math500 avg@k: full 73.8 → 79.9, delta 72.4 → 80.2; overall
avg@k: full 31.3 → 38.8, delta 31.1 → 38.3.
importance_weight ≈ 1.0every step(packed forward correct); zero crashes; resumed cleanly from an external kill.
Qwen3.5-4B (from #11): trains end-to-end and eval rises steadily.
Qwen3.5-4B re-validated on the pinned
0.5.13.post1release: bothfulland
deltare-run end-to-end onsglang==0.5.13.post1— training completes,full (
shard_copy) and delta (~7× compressed) weight transfer both work, andeval holds at baseline (~49–51% overall avg@k over a short run). No regression
vs the predecessor git build the table above came from.
Validated environment (Qwen3.5)
torch 2.11.0+cu130, transformers 5.8.1, kernels 0.14.1, SGLang 0.5.13.post1
(published release with
qwen3_5support) served viaattention_backend: flashinfer, flash-linear-attention 0.5.0, flashinfer-python 0.6.12,attn_impl=sdpa.pyproject.tomlpins the full validated stack:transformers==5.8.1(withkernels>=0.14,<0.15),torch==2.11.0,flash-linear-attention==0.5.0(thefast Qwen3.5 GDN kernels; optional, with a pure-torch fallback), and
sglang==0.5.13.post1— the published release that shipsqwen3_5(theolder 0.5.12.post1 predated it), which pulls flashinfer 0.6.12 in automatically.
uv pip install -e ".[sglang]"resolves cleanly to exactly the validatedversions.