feat(grpo): add stop_tool_names for immediate agent loop termination #5390
lailanelkoussy wants to merge 10 commits into huggingface:main from
Conversation
Adds GRPOConfig.stop_tool_names: list[str] | None. When the model calls a tool whose name is in this list, that sample exits the tool-calling loop immediately after the tool result is appended — no further generation occurs for it. Other samples in the batch are unaffected. Fixes huggingface#5389
Pull request overview
Adds support in GRPO’s agent/tool execution loop to treat certain tools as “terminal” (e.g., final_answer), so that when they’re called the sample stops generating further turns immediately—aligning training-time behavior with eval-time agent frameworks.
Changes:
- Added `stop_tool_names: list[str] | None` to `GRPOConfig`.
- Updated `GRPOTrainer._tool_call_loop` to detect stop tools per-sample and remove those samples from further generation after appending the tool result.
- Added a dedicated test suite covering stop-tool termination behavior and no-op behavior when disabled.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| trl/trainer/grpo_trainer.py | Implements stop-tool detection and per-sample early termination in the tool loop. |
| trl/trainer/grpo_config.py | Adds and documents the new stop_tool_names config field. |
| tests/test_grpo_trainer_stop_tool.py | Adds unit/integration tests for stop-tool termination semantics. |
trl/trainer/grpo_trainer.py
```python
# Extend completion_ids to include tool-result tokens.
completion_ids[idx_with_tool] = pct[prompt_length:]
# Mask tool-result tokens out of the loss.
tool_mask[idx_with_tool] += [0] * tool_suffix_length
if logprobs is not None:
    logprobs[idx_with_tool] += [0.0] * tool_suffix_length
```
Stop-tool samples update completion_ids to include the tool suffix, but this path doesn’t re-apply the max_completion_length truncation logic used later in the normal tool-flow. If a tool result (or its chat-template wrapper) is long, completion_ids / tool_mask / logprobs can exceed max_completion_length, which can increase memory/compute and break the expectation that completions are length-bounded. Consider truncating pct[prompt_length:] to self.max_completion_length here and trimming the corresponding mask/logprob extensions accordingly (mirroring the truncation logic used for the post-tool generation path).
Suggested change:

```diff
- # Extend completion_ids to include tool-result tokens.
- completion_ids[idx_with_tool] = pct[prompt_length:]
- # Mask tool-result tokens out of the loss.
- tool_mask[idx_with_tool] += [0] * tool_suffix_length
- if logprobs is not None:
-     logprobs[idx_with_tool] += [0.0] * tool_suffix_length
+ # Build the full completion (completion + tool suffix) and enforce max_completion_length.
+ full_completion = pct[prompt_length:]
+ final_length = min(len(full_completion), self.max_completion_length)
+ # Extend and truncate completion_ids to include (possibly truncated) tool-result tokens.
+ completion_ids[idx_with_tool] = full_completion[:final_length]
+ # Mask tool-result tokens out of the loss, keeping alignment with completion_ids.
+ extended_tool_mask = tool_mask[idx_with_tool] + [0] * tool_suffix_length
+ tool_mask[idx_with_tool] = extended_tool_mask[:final_length]
+ if logprobs is not None:
+     extended_logprobs = logprobs[idx_with_tool] + [0.0] * tool_suffix_length
+     logprobs[idx_with_tool] = extended_logprobs[:final_length]
```
trl/trainer/grpo_trainer.py
```python
idxs_with_tool = [i for i, keep in zip(idxs_with_tool, non_stop_mask, strict=False) if keep]
prompt_completion_tools = [
    p for p, keep in zip(prompt_completion_tools, non_stop_mask, strict=False) if keep
]
prompt_completion_tool_ids = [
    p for p, keep in zip(prompt_completion_tool_ids, non_stop_mask, strict=False) if keep
```
The filtering zips here use strict=False, but non_stop_mask is derived directly from idxs_with_tool, so the lengths should always match. Using strict=True (as done in nearby zips in this function) would fail fast if an invariant is broken instead of silently truncating, which makes debugging easier.
Suggested change:

```diff
- idxs_with_tool = [i for i, keep in zip(idxs_with_tool, non_stop_mask, strict=False) if keep]
- prompt_completion_tools = [
-     p for p, keep in zip(prompt_completion_tools, non_stop_mask, strict=False) if keep
- ]
- prompt_completion_tool_ids = [
-     p for p, keep in zip(prompt_completion_tool_ids, non_stop_mask, strict=False) if keep
+ idxs_with_tool = [i for i, keep in zip(idxs_with_tool, non_stop_mask, strict=True) if keep]
+ prompt_completion_tools = [
+     p for p, keep in zip(prompt_completion_tools, non_stop_mask, strict=True) if keep
+ ]
+ prompt_completion_tool_ids = [
+     p for p, keep in zip(prompt_completion_tool_ids, non_stop_mask, strict=True) if keep
```
Cursor Bugbot has reviewed your changes and found 2 potential issues.
@codex can this be simplified?
The stop-tool code path appended tool-result suffix tokens to `completion_ids` without enforcing `max_completion_length`, causing the padded completion tensor to exceed the expected width and silently increasing memory and `logits_to_keep` compute. All other paths (regular post-tool at L1571-1585, overlong at L1517) already enforced this limit. Truncate `full_completion`, `tool_mask`, and `logprobs` to `min(len, max_completion_length)` to match the regular post-tool path.

Also fix `strict=False` → `strict=True` in the three `zip()` calls filtering by `non_stop_mask`, consistent with every other `zip(..., strict=True)` in the file.

Add two regression tests:
- test_stop_tool_enforces_max_completion_length
- test_stop_tool_truncation_preserves_mask_and_logprobs_alignment
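The extend-then-truncate pattern described here can be sketched in isolation with plain Python lists standing in for the trainer's per-sample buffers (all names below are illustrative, not the actual TRL implementation):

```python
def append_tool_suffix(completion_ids, tool_mask, logprobs, suffix_ids, max_completion_length):
    """Append tool-result tokens to one sample's buffers, then truncate all three
    in lockstep so their lengths never exceed max_completion_length.

    Illustrative stand-in for the stop-tool path, not TRL code.
    """
    full = completion_ids + suffix_ids
    final_length = min(len(full), max_completion_length)
    completion_ids = full[:final_length]
    # Tool-result tokens are masked out of the loss (0) and carry no model logprob.
    tool_mask = (tool_mask + [0] * len(suffix_ids))[:final_length]
    if logprobs is not None:
        logprobs = (logprobs + [0.0] * len(suffix_ids))[:final_length]
    return completion_ids, tool_mask, logprobs

# A 3-token completion plus a 4-token tool suffix, capped at 5 tokens total.
ids, mask, lps = append_tool_suffix(
    [11, 12, 13], [1, 1, 1], [-0.5, -0.7, -0.2],
    suffix_ids=[90, 91, 92, 93], max_completion_length=5,
)
# All three buffers stay aligned and length-bounded: 5 entries each.
```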
…ilanelkoussy/trl into feat/stop-tool-agent-training
@albertvillanova I don't know how common this stop_tool feature is; do you know if it's something specific to smolagents or more widely used? If so, I'd be happy to add this feature.
@qgallouedec I was also considering proposing a PR to support varying the available tools per rollout (i.e., per datapoint / sampled trajectory used for training). I'm currently fine-tuning a model with tools and validating checkpoints using smolagents, and I think it would benefit from exposing the model to more diverse toolsets across rollouts.
Yes, I think we should find a way to make the tools vary. I'm not sure how yet, maybe with a column in the dataset? But in any case this is a separate concern and should not be combined with this PR.
Makes sense, thanks! I'll open a separate PR for this.
Yep, that would be consistent with SFT/DPO etc. OK, please let's continue this discussion in another issue or PR 👍

GRPOTrainer: add stop_tool_names for immediate agent loop termination

Fixes #5389
What this PR does
Adds a `stop_tool_names: list[str] | None` field to `GRPOConfig`. When the agent calls any tool whose name is in this list, that sample exits the `_tool_call_loop` immediately after the tool result is appended; no further model generation occurs for it. Other samples in the batch continue unaffected.

This enables `final_answer`-style termination tools (as used in smolagents and similar frameworks) to work identically during training and evaluation, eliminating a systematic behavioural divergence.

Changes

- `trl/trainer/grpo_config.py`: add the `stop_tool_names: list[str] | None` field with docstring.
- `trl/trainer/grpo_trainer.py`: in `_tool_call_loop`, after tool execution and after the overlong filter, identify samples that called a stop tool, update their `completion_ids`/`tool_mask`/`logprobs` to include the tool result tokens (masked from loss), then remove them from `idxs_with_tool` before `_generate_single_turn`.

Example
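A minimal usage sketch, assuming this PR's branch of trl is installed; the tool body and the other config arguments shown are placeholders, only `stop_tool_names` is the field this PR adds:

```python
# Hypothetical usage sketch; assumes trl with this PR applied. Not runnable against main.
from trl import GRPOConfig

def final_answer(answer: str) -> str:
    """Terminal tool: once the model calls it, the sample stops generating."""
    return answer

config = GRPOConfig(
    output_dir="out",
    max_tool_calling_iterations=5,
    stop_tool_names=["final_answer"],  # samples calling this tool exit the tool loop
)
```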
Behaviour
- `stop_tool_names=None` (default): no change from current behaviour.
- When a stop tool is called, the tool result is appended (role `"tool"`), those tokens are masked in `tool_mask` (value 0, excluded from loss), and no new completion is generated for that sample.
- If every sample calls a stop tool, the loop exits via the existing `if not idxs_with_tool: break` path.
- Interaction with `max_tool_calling_iterations`: stop-tool termination takes precedence; the sample exits regardless of how many iterations remain.
- The overlong filter runs first (it already removes samples from `idxs_with_tool`), so the stop-tool path only sees non-overlong samples.
The test verifies:
- A sample calling `final_answer` exits the loop after 1 tool-calling iteration (not `max_tool_calling_iterations`).
- The completion ends with the `final_answer` tool call + result, with no post-tool model tokens.
- `tool_mask` correctly marks the tool result tokens as 0.
- With `stop_tool_names=None`, behaviour is unchanged.

Note
Medium Risk
Touches GRPO's tool-calling generation loop and modifies how `completion_ids`/`tool_mask`/`logprobs` are built and truncated, which can affect training dynamics and memory usage. Behavior is gated behind `stop_tool_names` (default `None`), but incorrect masking/truncation could impact loss/reward calculations when enabled.

Overview
Adds a new `GRPOConfig.stop_tool_names` option to treat specific tools (e.g. `final_answer`) as terminal during GRPO agent training, so samples stop generating immediately after that tool executes.

Updates `GRPOTrainer._tool_call_loop` to detect stop-tool calls, append the tool-result suffix into `completion_ids`, mask those tool-result tokens out of the loss (`tool_mask=0`), truncate consistently to `max_completion_length` (including `logprobs` alignment), and remove those samples from further post-tool `generate` turns while letting other samples continue.

Introduces a comprehensive test suite (`tests/test_grpo_trainer_stop_tool.py`) that mocks `model.generate` to assert the reduced generation call count, mixed-batch behavior, reward-visible tool messages, finite loss, and truncation/mask/logprobs alignment, with `xfail` on older `transformers` without tool parsing.

Written by Cursor Bugbot for commit 4349c67.