Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions src/prime_rl/configs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic_config import BaseConfig

from prime_rl.configs.shared import BaseModelConfig, SlurmConfig
from prime_rl.utils.parsers import resolve_reasoning_parser, resolve_tool_call_parser
from prime_rl.utils.utils import rgetattr, rsetattr

# TODO: Set thinking/ solution budget
Expand Down Expand Up @@ -41,7 +42,12 @@ def __str__(self) -> str:


class ModelConfig(BaseModelConfig):
"""Configures the inference model. Most arguments are passed directly to the vLLM LLM class (https://docs.vllm.ai/en/latest/api/vllm.LLM.html)."""
"""Configures the inference model. Most arguments are passed directly to the vLLM LLM class (https://docs.vllm.ai/en/latest/api/vllm.LLM.html).

Parser fields (tool_call_parser, reasoning_parser) default to "auto",
which resolves to a concrete parser name at validation time based on
the model name. Set to None to disable.
"""

dtype: Annotated[
Literal["auto", "float16", "bfloat16", "float32"],
Expand Down Expand Up @@ -82,17 +88,16 @@ class ModelConfig(BaseModelConfig):
tool_call_parser: Annotated[
str | None,
Field(
description="The tool call parser to use. Passed to vLLM as `--tool-call-parser`. "
'Set to "auto" to infer from the model name.',
description='The tool call parser to use. Set to "auto" (default) to detect from the model name, or None to disable.',
),
] = "auto"

reasoning_parser: Annotated[
str | None,
Field(
description="Parser for extracting reasoning content from model outputs. Passed to vLLM as `--reasoning-parser`. Setting this enables reasoning mode.",
description='Parser for extracting reasoning content from model outputs. Set to "auto" (default) to detect from the model name, or None to disable.',
),
] = None
] = "auto"

rope_scaling: Annotated[
dict[str, Any] | str | None,
Expand All @@ -101,6 +106,15 @@ class ModelConfig(BaseModelConfig):
),
] = None

@model_validator(mode="after")
def auto_resolve_parsers(self):
    """Resolve "auto" parser values to concrete parser names from model name.

    Runs after field validation, so explicit parser names and None pass
    through untouched. When a resolver finds no pattern for ``self.name``
    it returns None, which disables that parser.
    """
    if self.tool_call_parser == "auto":
        self.tool_call_parser = resolve_tool_call_parser(self.name)
    if self.reasoning_parser == "auto":
        self.reasoning_parser = resolve_reasoning_parser(self.name)
    return self


class WeightBroadcastConfig(BaseConfig):
"""Configures weight broadcast settings."""
Expand Down Expand Up @@ -545,14 +559,19 @@ def to_vllm(self) -> Namespace:
# Set `logprobs_mode` to `processed_logprobs` by default
rsetattr(namespace, "logprobs_mode", "processed_logprobs")

# Remove chat_template if not set (vLLM doesn't accept None)
if namespace.chat_template is None:
delattr(namespace, "chat_template")
# Remove tool_call_parser if not set (vLLM doesn't accept None)
if namespace.tool_call_parser is None:
delattr(namespace, "tool_call_parser")
namespace.enable_auto_tool_choice = hasattr(namespace, "tool_call_parser")

# Remove reasoning_parser if not set (vLLM doesn't accept None)
if namespace.reasoning_parser is None:
delattr(namespace, "reasoning_parser")

# Remove chat_template if not set (vLLM doesn't accept None)
if namespace.chat_template is None:
delattr(namespace, "chat_template")

# Remove lora_target_modules if not set (vLLM doesn't accept None)
if hasattr(namespace, "lora_target_modules") and namespace.lora_target_modules is None:
delattr(namespace, "lora_target_modules")
Expand Down
115 changes: 0 additions & 115 deletions src/prime_rl/inference/vllm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,117 +22,6 @@
from prime_rl.configs.inference import InferenceConfig
from prime_rl.utils.logger import get_logger

MODEL_TOOL_CALL_PARSER: dict[str, str] = {
# GLM-4.5
"zai-org/GLM-4.5": "glm45",
"zai-org/GLM-4.5-FP8": "glm45",
"zai-org/GLM-4.5-Base": "glm45",
"zai-org/GLM-4.5-Air": "glm45",
"zai-org/GLM-4.5-Air-FP8": "glm45",
"zai-org/GLM-4.5-Air-Base": "glm45",
"zai-org/GLM-4.5V": "glm45",
"zai-org/GLM-4.5V-FP8": "glm45",
# GLM-4.7
"zai-org/GLM-4.7": "glm47",
"zai-org/GLM-4.7-FP8": "glm47",
"zai-org/GLM-4.7-Flash": "glm47",
# GLM-5
"zai-org/GLM-5": "glm47",
"zai-org/GLM-5-FP8": "glm47",
# GLM-5.1
"zai-org/GLM-5.1": "glm47",
"zai-org/GLM-5.1-FP8": "glm47",
# MiniMax M2
"MiniMaxAI/MiniMax-M2": "minimax_m2",
"MiniMaxAI/MiniMax-M2.1": "minimax_m2",
"MiniMaxAI/MiniMax-M2.5": "minimax_m2",
# INTELLECT-3
"PrimeIntellect/INTELLECT-3": "hermes",
"PrimeIntellect/INTELLECT-3-FP8": "hermes",
"PrimeIntellect/INTELLECT-3.1": "hermes",
# Qwen3 dense
"Qwen/Qwen3-0.6B": "hermes",
"Qwen/Qwen3-0.6B-Base": "hermes",
"Qwen/Qwen3-0.6B-FP8": "hermes",
"Qwen/Qwen3-1.7B": "hermes",
"Qwen/Qwen3-1.7B-Base": "hermes",
"Qwen/Qwen3-1.7B-FP8": "hermes",
"Qwen/Qwen3-4B": "hermes",
"Qwen/Qwen3-4B-Base": "hermes",
"Qwen/Qwen3-4B-FP8": "hermes",
"Qwen/Qwen3-8B": "hermes",
"Qwen/Qwen3-8B-Base": "hermes",
"Qwen/Qwen3-8B-FP8": "hermes",
"Qwen/Qwen3-14B": "hermes",
"Qwen/Qwen3-14B-Base": "hermes",
"Qwen/Qwen3-14B-FP8": "hermes",
"Qwen/Qwen3-32B": "hermes",
"Qwen/Qwen3-32B-FP8": "hermes",
# Qwen3 MoE
"Qwen/Qwen3-30B-A3B": "hermes",
"Qwen/Qwen3-30B-A3B-Base": "hermes",
"Qwen/Qwen3-30B-A3B-FP8": "hermes",
"Qwen/Qwen3-235B-A22B": "hermes",
"Qwen/Qwen3-235B-A22B-FP8": "hermes",
# Qwen3 2507
"Qwen/Qwen3-4B-Instruct-2507": "hermes",
"Qwen/Qwen3-4B-Thinking-2507": "hermes",
"Qwen/Qwen3-4B-Instruct-2507-FP8": "hermes",
"Qwen/Qwen3-4B-Thinking-2507-FP8": "hermes",
"Qwen/Qwen3-30B-A3B-Instruct-2507": "hermes",
"Qwen/Qwen3-30B-A3B-Thinking-2507": "hermes",
"Qwen/Qwen3-30B-A3B-Instruct-2507-FP8": "hermes",
"Qwen/Qwen3-30B-A3B-Thinking-2507-FP8": "hermes",
"Qwen/Qwen3-235B-A22B-Instruct-2507": "hermes",
"Qwen/Qwen3-235B-A22B-Thinking-2507": "hermes",
"Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "hermes",
"Qwen/Qwen3-235B-A22B-Thinking-2507-FP8": "hermes",
# Qwen3-Next
"Qwen/Qwen3-Next-80B-A3B-Instruct": "hermes",
"Qwen/Qwen3-Next-80B-A3B-Thinking": "hermes",
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "hermes",
"Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "hermes",
# Qwen3-Coder
"Qwen/Qwen3-Coder-480B-A35B-Instruct": "hermes",
"Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8": "hermes",
"Qwen/Qwen3-Coder-30B-A3B-Instruct": "hermes",
"Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8": "hermes",
# Qwen3-Coder-Next
"Qwen/Qwen3-Coder-Next": "hermes",
"Qwen/Qwen3-Coder-Next-Base": "hermes",
"Qwen/Qwen3-Coder-Next-FP8": "hermes",
# Qwen3.5 dense (uses qwen3_coder tool format, not hermes)
"Qwen/Qwen3.5-0.8B": "qwen3_coder",
"Qwen/Qwen3.5-0.8B-Base": "qwen3_coder",
"Qwen/Qwen3.5-2B": "qwen3_coder",
"Qwen/Qwen3.5-2B-Base": "qwen3_coder",
"Qwen/Qwen3.5-4B": "qwen3_coder",
"Qwen/Qwen3.5-4B-Base": "qwen3_coder",
"Qwen/Qwen3.5-9B": "qwen3_coder",
"Qwen/Qwen3.5-9B-Base": "qwen3_coder",
"Qwen/Qwen3.5-27B": "qwen3_coder",
"Qwen/Qwen3.5-27B-FP8": "qwen3_coder",
# Qwen3.5 MoE (uses qwen3_coder tool format, not hermes)
"Qwen/Qwen3.5-35B-A3B": "qwen3_coder",
"Qwen/Qwen3.5-35B-A3B-Base": "qwen3_coder",
"Qwen/Qwen3.5-35B-A3B-FP8": "qwen3_coder",
"Qwen/Qwen3.5-122B-A10B": "qwen3_coder",
"Qwen/Qwen3.5-122B-A10B-FP8": "qwen3_coder",
"Qwen/Qwen3.5-397B-A17B": "qwen3_coder",
"Qwen/Qwen3.5-397B-A17B-FP8": "qwen3_coder",
# NemotronH
"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "qwen3_coder",
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16": "qwen3_coder",
}


def resolve_tool_call_parser(model_name: str, tool_call_parser: str | None) -> str | None:
"""Resolve tool_call_parser from model name if set to "auto"."""
if tool_call_parser == "auto":
return MODEL_TOOL_CALL_PARSER.get(model_name)
return tool_call_parser


logger = get_logger()
from prime_rl.inference.patches import (
monkey_patch_harmony_stop_token_propagation,
Expand Down Expand Up @@ -340,11 +229,7 @@ def server(config: InferenceConfig, vllm_extra: dict[str, Any] | None = None):
assert args is not None
validate_parsed_serve_args(args)

args.tool_call_parser = resolve_tool_call_parser(args.model, args.tool_call_parser)
args.enable_auto_tool_choice = args.tool_call_parser is not None
args.reset_prefix_cache_after_update = config.experimental.reset_prefix_cache_after_update
if args.tool_call_parser is not None:
logger.info(f"Using tool_call_parser='{args.tool_call_parser}' for model '{args.model}'")

# Set the worker extension class based on the broadcast backend
args.worker_extension_cls = WORKER_EXTENSION_CLS[config.weight_broadcast.type]
Expand Down
45 changes: 45 additions & 0 deletions src/prime_rl/utils/parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import re

# (regex, parser_name) — first match wins.
# Maps HF-style "org/model" name prefixes to vLLM tool-call parser names.
# Scanned in order by resolve_parser(); a model matching no pattern resolves to None.
TOOL_CALL_PARSER_PATTERNS: list[tuple[re.Pattern[str], str]] = [
    (re.compile(r"^deepseek-ai/DeepSeek-V3\.2"), "deepseek_v32"),
    (re.compile(r"^deepseek-ai/DeepSeek-V3\.1"), "deepseek_v31"),
    (re.compile(r"^zai-org/GLM-4\.5"), "glm45"),
    (re.compile(r"^zai-org/GLM-4\.7"), "glm47"),
    (re.compile(r"^zai-org/GLM-5"), "glm47"),
    (re.compile(r"^MiniMaxAI/MiniMax-M2"), "minimax_m2"),
    (re.compile(r"^PrimeIntellect/INTELLECT-3"), "qwen3_coder"),
    (re.compile(r"^nvidia/NVIDIA-Nemotron-3"), "qwen3_coder"),
    (re.compile(r"^stepfun-ai/Step-3\.5"), "step3p5"),
    # Qwen3.5 and Qwen3-Coder use qwen3_coder — must be before the Qwen3 catch-all
    (re.compile(r"^Qwen/Qwen3\.5-"), "qwen3_coder"),
    (re.compile(r"^Qwen/Qwen3-Coder"), "qwen3_coder"),
    (re.compile(r"^Qwen/Qwen3-"), "hermes"),
]

# (regex, parser_name) — first match wins, mirroring TOOL_CALL_PARSER_PATTERNS.
# A model matching no pattern resolves to None (reasoning parsing disabled).
REASONING_PARSER_PATTERNS: list[tuple[re.Pattern[str], str]] = [
    (re.compile(r"^deepseek-ai/DeepSeek-V3\.[12]"), "deepseek_r1"),
    (re.compile(r"^zai-org/GLM-"), "glm45"),
    (re.compile(r"^MiniMaxAI/MiniMax-M2"), "minimax_m2_append_think"),
    (re.compile(r"^PrimeIntellect/INTELLECT-3"), "deepseek_r1"),
    (re.compile(r"^stepfun-ai/Step-3\.5"), "step3p5"),
    # Only Qwen3 Thinking models reason — Instruct models do not
    (re.compile(r"^Qwen/Qwen3-.*Thinking"), "deepseek_r1"),
    (re.compile(r"^Qwen/Qwen3\.5-"), "qwen3"),
]


def resolve_parser(model_name: str, patterns: list[tuple[re.Pattern[str], str]]) -> str | None:
"""Auto-detect parser from model name. Returns the first matching pattern's parser."""
for pattern, parser_name in patterns:
if pattern.search(model_name):
return parser_name
return None


def resolve_tool_call_parser(model_name: str) -> str | None:
    """Return the tool-call parser name for ``model_name``, or None if no pattern matches."""
    return resolve_parser(model_name, TOOL_CALL_PARSER_PATTERNS)


def resolve_reasoning_parser(model_name: str) -> str | None:
    """Return the reasoning parser name for ``model_name``, or None if no pattern matches."""
    return resolve_parser(model_name, REASONING_PARSER_PATTERNS)
28 changes: 28 additions & 0 deletions tests/unit/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,31 @@ def test_subconfig_seq_len_wins_over_shared():
)
assert config.trainer.model.seq_len == 8192
assert config.orchestrator.seq_len == 4096


def test_shared_model_name_resolves_inference_parser():
    """Shared model.name propagates to inference before ModelConfig runs its parser auto-resolver."""
    model_name = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
    raw = {
        "model": {"name": model_name},
        "trainer": {},
        "orchestrator": {},
        "inference": {},
    }
    config = RLConfig.model_validate(raw)
    inference = config.inference
    assert inference is not None
    assert inference.model.name == model_name
    assert inference.model.tool_call_parser == "qwen3_coder"


def test_explicit_inference_parser_wins_over_auto():
    """An explicitly configured tool_call_parser is not overwritten by "auto" resolution."""
    payload = {
        "model": {"name": "Qwen/Qwen3-Coder-30B-A3B-Instruct"},
        "trainer": {},
        "orchestrator": {},
        "inference": {"model": {"tool_call_parser": "hermes"}},
    }
    config = RLConfig.model_validate(payload)
    inference = config.inference
    assert inference is not None
    assert inference.model.tool_call_parser == "hermes"
Loading
Loading