Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions astraflow/raas/engine/remote_inf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,14 @@ def __init__(
self.lock = Lock()

self.lora_initialized = False
# Versioned LoRA adapter naming: each weight sync loads under a NEW
# name (``lora_v{seq}``) and we never unload. Unloading an adapter that
# still has paused/aborted in-flight requests deadlocks on SGLang's
# ``wait_for_unload`` (aborted requests never release their usage
# counter). New unique names avoid the unload entirely; SGLang's
# mem-pool LRU evicts stale adapters from GPU automatically.
self._lora_seq = 0
self._current_lora_name: str | None = None

self._executor: ProcessPoolExecutor | None = None
self._paused: bool = False
Expand Down Expand Up @@ -654,7 +662,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse:
f"agenerate() building HTTP request, rid={req.rid}, "
f"iteration={iteration}, server_addr={server_addr}"
)
http_req = self.backend.build_generation_request(req, self.lora_initialized)
http_req = self.backend.build_generation_request(req, self._current_lora_name)

# Loop until the generation is complete
logger.debug(
Expand Down Expand Up @@ -745,19 +753,25 @@ def load_weights_from_path(
For full weights: ``/update_weights_from_disk`` includes
``abort_all_requests: True`` and ``flush_cache`` internally.

For LoRA adapters (``use_lora=True``): unloads the old adapter,
loads the new one, then flushes the KV cache via ``/flush_cache``
to discard stale entries computed with the old LoRA weights.
Relies on sglang releasing the ``lora_registry`` counter for
aborted requests (fixed upstream in
``TokenizerManager._handle_abort_finish_reason`` as of 0.5.12).
For LoRA adapters (``use_lora=True``): loads the new adapter under a
fresh versioned name (``lora_v{seq}``) without unloading the previous
one, then flushes the KV cache. We never unload because unloading an
adapter with paused/aborted in-flight requests deadlocks SGLang's
``wait_for_unload``; SGLang's mem-pool LRU reclaims GPU slots for
stale adapters automatically (bounded by ``max_loras_per_batch``).
"""
import time as _time

_t0 = _time.monotonic()
lora_name = "lora_1"

if use_lora:
# Load under a NEW versioned name and do NOT unload the old one.
# Unloading an adapter with paused/aborted in-flight requests
# deadlocks SGLang's ``wait_for_unload`` (the usage counter for
# aborted requests is never released). A fresh name has no such
# counter; SGLang's mem-pool LRU evicts stale adapters from GPU.
self._lora_seq += 1
lora_name = f"lora_v{self._lora_seq}"
logger.info(
"load_weights_from_path: sending /load_lora_adapter "
"to %d servers (path=%s, lora_name=%s) ...",
Expand All @@ -766,19 +780,13 @@ def load_weights_from_path(
lora_name,
)
try:
if self.lora_initialized:
unload_req = HttpRequest(
endpoint="/unload_lora_adapter",
payload={"lora_name": lora_name},
)
self._run_request_on_all_servers(unload_req)

load_req = HttpRequest(
endpoint="/load_lora_adapter",
payload={"lora_name": lora_name, "lora_path": str(path)},
)
self._run_request_on_all_servers(load_req)
self.lora_initialized = True
self._current_lora_name = lora_name

# Flush stale KV cache entries computed with old LoRA weights.
# Safe because caller already paused generation (is_pause=True
Expand Down
12 changes: 8 additions & 4 deletions astraflow/raas/engine/sglang_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ class SGLangBackend:
"""Backend that translates engine operations into SGLang HTTP API calls."""

def build_generation_request(
self, req: ModelRequest, with_lora: bool
self, req: ModelRequest, lora_name: str | None
) -> HttpRequest:
"""Convert a ModelRequest into an SGLang /generate HTTP request."""
"""Convert a ModelRequest into an SGLang /generate HTTP request.

``lora_name`` is the currently-active versioned adapter name (e.g.
``lora_v3``) or ``None`` when no adapter is loaded.
"""
gconfig = req.gconfig
stop_token_ids = gconfig.stop_token_ids
stop = gconfig.stop
Expand Down Expand Up @@ -55,8 +59,8 @@ def build_generation_request(
"stream": False,
}

if with_lora:
payload["lora_path"] = "lora_1"
if lora_name:
payload["lora_path"] = lora_name

return HttpRequest(endpoint="/generate", payload=payload)

Expand Down
15 changes: 12 additions & 3 deletions astraflow/raas/engine/vllm_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,14 @@ def __init__(self):
pass

def build_generation_request(
self, req: ModelRequest, with_lora: bool
self, req: ModelRequest, lora_name: str | None
) -> HttpRequest:
"""Convert a ModelRequest into a vLLM completions or chat HTTP request."""
"""Convert a ModelRequest into a vLLM completions or chat HTTP request.

``lora_name`` is a truthy marker that a LoRA is active; vLLM selects
the adapter via ``gconfig.lora_name`` (its own naming), so the marker's
value is unused here.
"""
gconfig = req.gconfig
stop_token_ids = gconfig.stop_token_ids
stop = gconfig.stop
Expand All @@ -54,7 +59,7 @@ def build_generation_request(
if stop:
payload["stop"] = stop

if with_lora and len(gconfig.lora_name) > 0:
if lora_name and len(gconfig.lora_name) > 0:
payload["model"] = gconfig.lora_name

if req.vision_msg_vllm:
Expand Down Expand Up @@ -181,6 +186,10 @@ def __init__(self, config: InferenceEngineConfig):
self.config = config
self._engine = RemoteInfEngine(config, VLLMBackend())
self._engine.lora_initialized = config.use_lora
# vLLM selects the adapter via gconfig.lora_name; this just marks LoRA
# active so the shared generation-request builder passes a truthy flag.
if config.use_lora:
self._engine._current_lora_name = "vllm_lora"

def __getattr__(self, name: str):
return getattr(self._engine, name)
Expand Down
6 changes: 5 additions & 1 deletion astraflow/raas/server/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,11 +1815,15 @@ async def _do_weight_update(
model_id, exc_info=True,
)

# Sync LoRA state to eval engines
# Sync LoRA state to eval engines. They share the same SGLang server,
# so they must use the same versioned adapter name in generation.
if use_lora:
main_inner = getattr(engine, "_engine", engine)
cur_name = getattr(main_inner, "_current_lora_name", None)
for eval_eng in self._eval_engines.values():
inner = getattr(eval_eng, "_engine", eval_eng)
inner.lora_initialized = True
inner._current_lora_name = cur_name

_timing = (
f"notify_version: loaded {model_id} v={version} "
Expand Down
8 changes: 8 additions & 0 deletions docs/en/recipes/math.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ bash examples/math/qwen3-1.7b-m2po-2gpus-full/scripts/run_qwen3-1.7b-m2po-2gpus-
| Train dataset | DeepScaleR |
| Eval datasets | AIME24, AIME25, AMC, Minerva Math, MATH500 |

### LoRA variant

[`qwen3-1.7b-m2po-2gpus-lora/`](https://github.com/Infini-AI-Lab/astraflow/tree/main/examples/math/qwen3-1.7b-m2po-2gpus-lora) trains a LoRA adapter on the actor instead of full fine-tuning, keeping the same 2-GPU layout. Each step the trainer syncs the adapter to the SGLang server under a fresh versioned name (`lora_v{n}`) and never unloads it — SGLang's memory-pool LRU reclaims old versions — which avoids the unload deadlock that occurs when an adapter still holds aborted in-flight requests. One important caveat: a LoRA update is effectively much larger than a full-fine-tuning step at the same learning rate (the `alpha/rank` scaling), so LoRA needs near-on-policy training to stay stable. The recipe therefore sets `ppo_n_minibatches: 1`, `max_staleness: 1`, and `recompute_logprob: true` (with `lr` 5e-6); with these it shows a clean rising eval curve. On each weight sync the server first pauses generation and drains its in-flight requests (aborting any still running), then loads the new adapter under the fresh versioned name and flushes the stale KV cache before resuming — the old adapter is deliberately never unloaded, because unloading one that still holds aborted requests would block SGLang's `wait_for_unload` forever. Run it with:

```bash
bash examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/run_qwen3-1.7b-m2po-2gpus-lora.sh
```

## Qwen3-8B — 8 GPUs

The full-scale recipe. It needs an 8-GPU node — 4 GPUs for inference, 4 for training — and also comes in full and delta transfer variants:
Expand Down
36 changes: 36 additions & 0 deletions examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/1_astraflow.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/bin/bash
set -euo pipefail
# [1/3] Launch AstraFlow HTTP service
#
# Usage (terminal 1):
# bash examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/1_astraflow.sh

export CUDA_VISIBLE_DEVICES=""

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
cd "${REPO_ROOT}"
export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"

YAML_DIR="${SCRIPT_DIR}/yaml"
export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}"
source "${REPO_ROOT}/examples/_common/utils.sh"
# Export EXP_NAME and TRIAL_NAME from the experiment YAML.
astraflow_load_experiment_env

export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}"
export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}"

# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh.
astraflow_setup_env

echo "=== AstraFlow HTTP Service ==="
echo "Experiment config : ${EXPERIMENT_CONFIG}"
echo "Port : ${ASTRAFLOW_PORT}"
echo "==============================="

python3 -u -m astraflow \
--config "${EXPERIMENT_CONFIG}" \
--port "${ASTRAFLOW_PORT}" \
--host "${ASTRAFLOW_HOST}" \
2>&1 | tee "${LOG_DIR}/astraflow.log"
44 changes: 44 additions & 0 deletions examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/2_raas.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/bin/bash
set -euo pipefail
# [2/3] Launch RaaS inference server (SGLang + TCP receiver)
#
# Usage (terminal 2, after AstraFlow is ready):
# bash examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/2_raas.sh

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
cd "${REPO_ROOT}"
export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"

YAML_DIR="${SCRIPT_DIR}/yaml"
export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}"
export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}"
source "${REPO_ROOT}/examples/_common/utils.sh"
# Export EXP_NAME and TRIAL_NAME from the experiment YAML.
astraflow_load_experiment_env

export CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0}"
export RAAS_HOST="${RAAS_HOST:-0.0.0.0}"
export RAAS_PORT="${RAAS_PORT:-19190}"
export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}"
export ASTRAFLOW_URL="${ASTRAFLOW_URL:-http://127.0.0.1:${ASTRAFLOW_PORT}}"

# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh.
astraflow_setup_env

echo "=== RaaS Inference Server (SGLang + TCP receiver) ==="
echo "Experiment config : ${EXPERIMENT_CONFIG}"
echo "RaaS config : ${RAAS_CONFIG}"
echo "GPUs : ${CUDA_VISIBLE_DEVICES}"
echo "Port : ${RAAS_PORT}"
echo "AstraFlow URL : ${ASTRAFLOW_URL}"
echo "======================================================="

python3 -u -m astraflow.raas.server \
--host "${RAAS_HOST}" \
--port "${RAAS_PORT}" \
--config "${EXPERIMENT_CONFIG}" \
--config "${RAAS_CONFIG}" \
--engine-id "${ENGINE_ID:-default}" \
--astraflow-url "${ASTRAFLOW_URL}" \
2>&1 | tee "${LOG_DIR}/raas.log"
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash
set -euo pipefail
# [3/3] Launch Trainer for model0 (TCP, sender_agent on local_rank 0)
#
# Usage (terminal 3, after AstraFlow and RaaS are ready):
# bash examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/3_trainer_model0.sh

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
cd "${REPO_ROOT}"
export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"

YAML_DIR="${SCRIPT_DIR}/yaml"
export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}"
source "${REPO_ROOT}/examples/_common/utils.sh"
# Export EXP_NAME and TRIAL_NAME from the experiment YAML.
astraflow_load_experiment_env

export CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS:-1}"
TRAINER0_NPROC="$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F',' '{print NF}')"

export RAAS_PORT="${RAAS_PORT:-19190}"
export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}"
export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}"
export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}"

# sender_agent (in trainer) listens on this HTTP port
export WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}"

# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh.
astraflow_setup_env

echo "=== Trainer model0 (TCP) ==="
echo "Experiment config : ${EXPERIMENT_CONFIG}"
echo "GPUs : ${CUDA_VISIBLE_DEVICES} (FSDP dp${TRAINER0_NPROC})"
echo "AstraFlow : ${ASTRAFLOW_URL}"
echo "RaaS : ${ASTRAFLOW_RAAS_URL}"
echo "Sender HTTP : ${WEIGHT_TRANSFER_HTTP_PORT}"
echo "WANDB mode : ${WANDB_MODE:-online}"
echo "=========================================="

torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \
--master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \
examples/launch_trainer.py \
--config "${EXPERIMENT_CONFIG}" \
--trainer trainer_model0 \
"$@" 2>&1 | tee "${LOG_DIR}/trainer_model0.log"
Loading
Loading