Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions astraflow/train_worker/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,22 @@ class TrainEngineConfig:
trial_name: str = ""
path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
attn_impl: str = field(
default="flash_attention_2",
default="kernels-community/flash-attn2",
metadata={
"help": "Attention implementation for huggingface transformers model.",
"choices": ["flash_attention_2"],
"help": (
"Attention implementation for huggingface transformers model. "
"Default pulls a prebuilt FlashAttention-2 kernel from the HF kernels "
"hub (ABI-matched to torch, incl. varlen for packed sequences). The "
"literal 'flash_attention_2' loads the local flash-attn wheel, which is "
"ABI-broken on torch>=2.11+cu13; 'sdpa' works but relies on position_ids "
"resets for packed block-diagonal masking."
),
"choices": [
"kernels-community/flash-attn2",
"flash_attention_2",
"sdpa",
"eager",
],
},
)
init_from_scratch: bool = field(
Expand Down
21 changes: 9 additions & 12 deletions astraflow/train_worker/engine/fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@
from astraflow.train_worker.utils.model import (
disable_dropout_in_model,
is_gemma3_model,
is_qwen3_moe_model,
is_qwen3_vl_model,
is_qwen_vl_model,
is_valid_vision_model,
)
Expand Down Expand Up @@ -1206,16 +1204,15 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
]
mb["use_cache"] = False
padded_mb["use_cache"] = False
if is_qwen3_moe_model(self.model_config.model_type) or is_qwen3_vl_model(
self.model_config.model_type
):
mb["attention_mask"] = None
padded_mb["attention_mask"] = None
else:
mb["attention_mask"] = dict(full_attention=None, sliding_attention=None)
padded_mb["attention_mask"] = dict(
full_attention=None, sliding_attention=None
)
# Always pass attention_mask=None for the packed/varlen forward: per-sequence
# causal masking is driven by cu_seqlens + position_ids, and the model builds
# the right mask itself when it is given None. The old
# dict(full_attention=None, sliding_attention=None) form is a transformers-4.x
# relic: on transformers>=5 a dense model (qwen3 / qwen2) treats that dict as a
# *precomputed* mask, skips mask creation, and crashes.
mb["attention_mask"] = None
padded_mb["attention_mask"] = None
if "multi_modal_input" in mb:
image_grid_thw_list = [
item["image_grid_thw"]
Expand Down
3 changes: 3 additions & 0 deletions astraflow/train_worker/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
"qwen2_vl",
"qwen2_5_vl",
"qwen3_vl",
# Qwen3.5 dense math checkpoints ship as Qwen3_5ForConditionalGeneration, so they
# load via the ImageTextToText path even though these recipes train text-only.
"qwen3_5",
"gemma3",
]
# Registry of vision models verified to work with this framework.
Expand Down
18 changes: 18 additions & 0 deletions examples/math/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,21 @@ Complete guidance: [`docs/en/recipes/math.md`](../../docs/en/recipes/math.md).

Most math recipes default to one 8xH100 node. The `qwen3-1.7b-m2po-2gpus-*`
recipes are smaller 2xH100 variants.

---
**Attention kernel**

The dense Qwen3 recipes (`qwen3-1.7b-m2po-2gpus-*`, `qwen3-8b-m2po-*`) set
`attn_impl: kernels-community/flash-attn2` — a prebuilt, ABI-matched
FlashAttention-2 kernel pulled from the Hugging Face `kernels` hub (fetched and
cached on first use; no source build). This is the working FA2 on the validated
stack (`torch 2.11+cu130`): the literal `attn_impl: flash_attention_2` would
instead load the local `flash-attn` wheel and crash with an `undefined symbol`
ABI error (`is_flash_attn_2_available()` is metadata-only, so it never catches
the broken import). `kernels-community/flash-attn2` is also the default value of
`attn_impl` in `cli_args.py`, so recipes that omit `attn_impl` get the same
kernel.

`sdpa` and `eager` remain available; `sdpa` works but relies on per-sequence
`position_ids` resets for packed block-diagonal masking, whereas FA2 varlen
derives the block-diagonal mask from `cu_seqlens` directly. The Qwen3.5 recipes
use `sdpa` (hybrid Gated-DeltaNet + attention model).
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ trainer_base:
data_parallel_size: 1

actor:
attn_impl: kernels-community/flash-attn2
gradient_checkpointing: true
mb_spec:
max_tokens_per_mb: 17408
Expand All @@ -135,6 +136,7 @@ trainer_base:
adv_norm: { mean_level: batch, std_level: batch }

ref:
attn_impl: kernels-community/flash-attn2
mb_spec:
max_tokens_per_mb: 17408

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ trainer_base:
data_parallel_size: 1

actor:
attn_impl: kernels-community/flash-attn2
gradient_checkpointing: true
mb_spec:
max_tokens_per_mb: 17408
Expand All @@ -134,6 +135,7 @@ trainer_base:
adv_norm: { mean_level: batch, std_level: batch }

ref:
attn_impl: kernels-community/flash-attn2
mb_spec:
max_tokens_per_mb: 17408

Expand Down
2 changes: 2 additions & 0 deletions examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ trainer_base:
data_parallel_size: 4

actor:
attn_impl: kernels-community/flash-attn2
gradient_checkpointing: true
mb_spec:
max_tokens_per_mb: 17408
Expand All @@ -135,6 +136,7 @@ trainer_base:
adv_norm: { mean_level: batch, std_level: batch }

ref:
attn_impl: kernels-community/flash-attn2
mb_spec:
max_tokens_per_mb: 17408

Expand Down
2 changes: 2 additions & 0 deletions examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ trainer_base:
data_parallel_size: 4

actor:
attn_impl: kernels-community/flash-attn2
gradient_checkpointing: true
mb_spec:
max_tokens_per_mb: 17408
Expand All @@ -134,6 +135,7 @@ trainer_base:
adv_norm: { mean_level: batch, std_level: batch }

ref:
attn_impl: kernels-community/flash-attn2
mb_spec:
max_tokens_per_mb: 17408

Expand Down
17 changes: 17 additions & 0 deletions examples/math/qwen3.5-4b-m2po-delta/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Qwen3.5-4B — Math RL (M2PO), delta weight transfer

Same recipe as [`qwen3.5-4b-m2po-full`](../qwen3.5-4b-m2po-full/README.md), but
the trainer pushes **only changed weights** to the inference engine each sync
(`weight_transfer_strategies: delta`) instead of the full model.

See the [full recipe's README](../qwen3.5-4b-m2po-full/README.md) for the
validated environment (transformers 5.8.1; kernels 0.14.1; SGLang
`0.5.13.post1` with `qwen3_5` and `attention_backend: flashinfer`; `fla` 0.5.0;
torch 2.11.0+cu130), GPU layout, install note, and validation results.

## Run

```bash
bash examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh
```
36 changes: 36 additions & 0 deletions examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/bin/bash
set -euo pipefail
# Step 1 of 3: start the AstraFlow HTTP service for this recipe.
#
# Run in terminal 1:
#   bash examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh
#
# Optional environment overrides: EXPERIMENT_CONFIG, ASTRAFLOW_HOST,
# ASTRAFLOW_PORT.

# Expose no GPUs to this process.
export CUDA_VISIBLE_DEVICES=""

# Resolve the recipe directory (parent of scripts/) and the repository root,
# then run from the repository root with it at the front of PYTHONPATH.
_this_dir="$(dirname "${BASH_SOURCE[0]}")"
SCRIPT_DIR="$(cd "${_this_dir}/.." && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
cd "${REPO_ROOT}"
if [[ -n "${PYTHONPATH:-}" ]]; then
  export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH}"
else
  export PYTHONPATH="${REPO_ROOT}"
fi

YAML_DIR="${SCRIPT_DIR}/yaml"
: "${EXPERIMENT_CONFIG:=${YAML_DIR}/experiment.yaml}"
export EXPERIMENT_CONFIG

# Shared helpers; astraflow_load_experiment_env exports EXP_NAME and TRIAL_NAME
# from the experiment YAML.
source "${REPO_ROOT}/examples/_common/utils.sh"
astraflow_load_experiment_env

: "${ASTRAFLOW_HOST:=0.0.0.0}"
: "${ASTRAFLOW_PORT:=8000}"
export ASTRAFLOW_HOST ASTRAFLOW_PORT

# Applies the NCCL / PYTORCH / WANDB tweaks and defines LOG_DIR
# (see examples/_common/utils.sh).
astraflow_setup_env

cat <<EOF
=== AstraFlow HTTP Service ===
Experiment config : ${EXPERIMENT_CONFIG}
Port : ${ASTRAFLOW_PORT}
===============================
EOF

# Unbuffered Python so log lines reach the tee'd log file promptly.
launch_cmd=(
  python3 -u -m astraflow
  --config "${EXPERIMENT_CONFIG}"
  --port "${ASTRAFLOW_PORT}"
  --host "${ASTRAFLOW_HOST}"
)
"${launch_cmd[@]}" 2>&1 | tee "${LOG_DIR}/astraflow.log"
44 changes: 44 additions & 0 deletions examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/bin/bash
set -euo pipefail
# Step 2 of 3: start the RaaS inference server (SGLang + TCP receiver).
#
# Run in terminal 2 once AstraFlow is up:
#   bash examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh
#
# Optional environment overrides: EXPERIMENT_CONFIG, RAAS_CONFIG,
# SERVICE_CUDA_VISIBLE_DEVICES, RAAS_HOST, RAAS_PORT, ASTRAFLOW_PORT,
# ASTRAFLOW_URL, ENGINE_ID.

# Resolve the recipe directory (parent of scripts/) and the repository root,
# then run from the repository root with it at the front of PYTHONPATH.
_this_dir="$(dirname "${BASH_SOURCE[0]}")"
SCRIPT_DIR="$(cd "${_this_dir}/.." && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
cd "${REPO_ROOT}"
if [[ -n "${PYTHONPATH:-}" ]]; then
  export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH}"
else
  export PYTHONPATH="${REPO_ROOT}"
fi

YAML_DIR="${SCRIPT_DIR}/yaml"
: "${EXPERIMENT_CONFIG:=${YAML_DIR}/experiment.yaml}"
: "${RAAS_CONFIG:=${YAML_DIR}/raas.yaml}"
export EXPERIMENT_CONFIG RAAS_CONFIG

# Shared helpers; astraflow_load_experiment_env exports EXP_NAME and TRIAL_NAME
# from the experiment YAML.
source "${REPO_ROOT}/examples/_common/utils.sh"
astraflow_load_experiment_env

# GPUs for the inference service, then the service and AstraFlow endpoints.
export CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}"
: "${RAAS_HOST:=0.0.0.0}"
: "${RAAS_PORT:=19190}"
: "${ASTRAFLOW_PORT:=8000}"
# The URL default is derived from the port, so it must be resolved after it.
: "${ASTRAFLOW_URL:=http://127.0.0.1:${ASTRAFLOW_PORT}}"
export RAAS_HOST RAAS_PORT ASTRAFLOW_PORT ASTRAFLOW_URL

# Applies the NCCL / PYTORCH / WANDB tweaks and defines LOG_DIR
# (see examples/_common/utils.sh).
astraflow_setup_env

cat <<EOF
=== RaaS Inference Server (SGLang + TCP receiver) ===
Experiment config : ${EXPERIMENT_CONFIG}
RaaS config : ${RAAS_CONFIG}
GPUs : ${CUDA_VISIBLE_DEVICES}
Port : ${RAAS_PORT}
AstraFlow URL : ${ASTRAFLOW_URL}
=======================================================
EOF

# Both YAML files are passed via repeated --config flags, experiment first.
launch_cmd=(
  python3 -u -m astraflow.raas.server
  --host "${RAAS_HOST}"
  --port "${RAAS_PORT}"
  --config "${EXPERIMENT_CONFIG}"
  --config "${RAAS_CONFIG}"
  --engine-id "${ENGINE_ID:-default}"
  --astraflow-url "${ASTRAFLOW_URL}"
)
"${launch_cmd[@]}" 2>&1 | tee "${LOG_DIR}/raas.log"
47 changes: 47 additions & 0 deletions examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash
set -euo pipefail
# [3/3] Launch Trainer for model0 (TCP, sender_agent on local_rank 0)
#
# Usage (terminal 3, after AstraFlow and RaaS are ready):
#   bash examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh [extra args]
#
# Any extra arguments are appended to the examples/launch_trainer.py command.
#
# Optional environment overrides (defaults in parentheses):
#   EXPERIMENT_CONFIG                 (<recipe>/yaml/experiment.yaml)
#   TRAINER_MODEL0_GPUS               (4,5,6,7)
#   RAAS_PORT                         (19190)
#   ASTRAFLOW_PORT                    (8000)
#   ASTRAFLOW_URL                     (http://127.0.0.1:${ASTRAFLOW_PORT})
#   ASTRAFLOW_RAAS_URL                (http://127.0.0.1:${RAAS_PORT})
#   WEIGHT_TRANSFER_HTTP_PORT_MODEL0  (19861)
#   MASTER_ADDR                       (127.0.0.1)
#   MASTER_PORT_MODEL0                (29541)

# SCRIPT_DIR is the recipe directory (parent of scripts/); REPO_ROOT is three
# levels above it. Run from the repo root with it at the front of PYTHONPATH.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
cd "${REPO_ROOT}"
export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"

YAML_DIR="${SCRIPT_DIR}/yaml"
export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}"
source "${REPO_ROOT}/examples/_common/utils.sh"
# Export EXP_NAME and TRIAL_NAME from the experiment YAML.
astraflow_load_experiment_env

# One torchrun worker per visible GPU: count the comma-separated device ids.
export CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS:-4,5,6,7}"
TRAINER0_NPROC="$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F',' '{print NF}')"

export RAAS_PORT="${RAAS_PORT:-19190}"
export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}"
# Honor caller-supplied endpoints (as 2_raas.sh does for ASTRAFLOW_URL) instead
# of silently overwriting them; the loopback defaults are unchanged.
export ASTRAFLOW_URL="${ASTRAFLOW_URL:-http://127.0.0.1:${ASTRAFLOW_PORT}}"
export ASTRAFLOW_RAAS_URL="${ASTRAFLOW_RAAS_URL:-http://127.0.0.1:${RAAS_PORT}}"

# sender_agent (in trainer) listens on this HTTP port
export WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}"

# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh.
astraflow_setup_env

echo "=== Trainer model0 (TCP) ==="
echo "Experiment config : ${EXPERIMENT_CONFIG}"
echo "GPUs : ${CUDA_VISIBLE_DEVICES} (FSDP dp${TRAINER0_NPROC})"
echo "AstraFlow : ${ASTRAFLOW_URL}"
echo "RaaS : ${ASTRAFLOW_RAAS_URL}"
echo "Sender HTTP : ${WEIGHT_TRANSFER_HTTP_PORT}"
echo "WANDB mode : ${WANDB_MODE:-online}"
echo "=========================================="

# With pipefail set, a torchrun failure still fails the script despite the tee.
torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \
  --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \
  examples/launch_trainer.py \
  --config "${EXPERIMENT_CONFIG}" \
  --trainer trainer_model0 \
  "$@" 2>&1 | tee "${LOG_DIR}/trainer_model0.log"
Loading