
feat(grpo): add stop_tool_names for immediate agent loop termination#5390

Open
lailanelkoussy wants to merge 10 commits into huggingface:main from lailanelkoussy:feat/stop-tool-agent-training

Conversation

@lailanelkoussy
Contributor

@lailanelkoussy lailanelkoussy commented Mar 27, 2026

GRPOTrainer: add stop_tool_names for immediate agent loop termination

Fixes #5389

What this PR does

Adds a stop_tool_names: list[str] | None field to GRPOConfig. When the agent calls any tool whose name is in this list, that sample exits the _tool_call_loop immediately after the tool result is appended — no further model generation occurs for it. Other samples in the batch continue unaffected.

This enables final_answer-style termination tools (as used in smolagents and similar frameworks) to work identically during training and evaluation, eliminating a systematic behavioural divergence.

Changes

trl/trainer/grpo_config.py

  • Added stop_tool_names: list[str] | None field with docstring

trl/trainer/grpo_trainer.py

  • In _tool_call_loop: after tool execution and after the overlong filter, identify samples that called a stop tool, update their completion_ids / tool_mask / logprobs to include the tool result tokens (masked from loss), then remove them from idxs_with_tool before _generate_single_turn
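The per-sample removal step described above can be sketched as follows. This is a simplified illustration with hypothetical names (`filter_stop_tool_samples`, `tool_calls`), not the trainer's actual implementation, which operates on its internal token and index lists:

```python
def filter_stop_tool_samples(idxs_with_tool, tool_calls, stop_tool_names):
    """Split tool-calling samples into those that continue and those that stop.

    idxs_with_tool: batch indices whose last turn contained a tool call.
    tool_calls: mapping from batch index to the name of the tool it called.
    stop_tool_names: the configured terminal tool names (or None/empty).
    """
    if not stop_tool_names:
        # Default behaviour: nothing is terminal, every sample continues.
        return list(idxs_with_tool), []
    stopped = [i for i in idxs_with_tool if tool_calls[i] in stop_tool_names]
    stopped_set = set(stopped)
    # Continuing samples go on to the next _generate_single_turn call.
    continuing = [i for i in idxs_with_tool if i not in stopped_set]
    return continuing, stopped
```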

Example

def final_answer(answer: int) -> int:
    """Submit the final answer to the task.
    Args:
        answer: The final integer result.
    Returns:
        The submitted answer.
    """
    return answer

config = GRPOConfig(
    max_tool_calling_iterations=15,
    stop_tool_names=["final_answer"],
    ...
)

trainer = GRPOTrainer(
    args=config,
    tools=[tool_a, tool_b, final_answer],
    ...
)

Behaviour

  • stop_tool_names=None (default): no change from current behaviour.
  • When a stop tool is called: the tool result is appended to the conversation (role "tool"), those tokens are masked in tool_mask (value 0, excluded from loss), and no new completion is generated for that sample.
  • If all samples in a batch call a stop tool in the same iteration, the loop exits early via the existing if not idxs_with_tool: break path.
  • Interaction with max_tool_calling_iterations: stop-tool termination takes precedence; the sample exits regardless of how many iterations remain.
  • Interaction with overlong filter: a sample that is both overlong and calls a stop tool is handled by the existing overlong path first (it is already removed from idxs_with_tool), so the stop-tool path only sees non-overlong samples.
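The masking rule in the second bullet can be illustrated with a minimal sketch (hypothetical helper name; the real code also handles truncation and logprob alignment):

```python
def append_tool_result(completion_ids, tool_mask, result_ids):
    """Append tool-result tokens to a completion, masked out of the loss.

    A mask value of 0 means the token is excluded from the loss; the model
    is trained only on its own generated tokens, never on tool output.
    """
    completion_ids = completion_ids + result_ids
    tool_mask = tool_mask + [0] * len(result_ids)
    return completion_ids, tool_mask
```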

Testing

pytest tests/test_grpo_trainer_stop_tool.py -sv

The test verifies:

  1. A model that calls final_answer exits the loop after 1 tool-calling iteration (not max_tool_calling_iterations).
  2. The completion contains the final_answer tool call + result, with no post-tool model tokens.
  3. tool_mask correctly marks the tool result tokens as 0.
  4. With stop_tool_names=None, behaviour is unchanged.
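Check (1) boils down to counting generation turns. A self-contained sketch of that idea, using a plain counter instead of the test suite's mocked `model.generate` (all names here are hypothetical):

```python
class FakeModel:
    """Stand-in model that always emits a final_answer tool call."""
    def __init__(self):
        self.generate_calls = 0

    def generate(self):
        self.generate_calls += 1
        return "final_answer(42)"

def run_tool_loop(model, stop_tool_names, max_iterations):
    """Simplified tool-calling loop mirroring the stop-tool semantics."""
    for _ in range(max_iterations):
        output = model.generate()
        tool_name = output.split("(")[0]
        if tool_name in stop_tool_names:
            # Stop tool called: exit immediately, no further generation.
            break
    return model.generate_calls
```

With a stop tool configured, exactly one generation call happens; without one, the loop runs to `max_iterations`.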

Note

Medium Risk
Touches GRPO’s tool-calling generation loop and modifies how completion_ids/tool_mask/logprobs are built and truncated, which can affect training dynamics and memory usage. Behavior is gated behind stop_tool_names (default None), but incorrect masking/truncation could impact loss/reward calculations when enabled.

Overview
Adds a new GRPOConfig.stop_tool_names option to treat specific tools (e.g. final_answer) as terminal during GRPO agent training, so samples stop generating immediately after that tool executes.

Updates GRPOTrainer._tool_call_loop to detect stop-tool calls, append the tool-result suffix into completion_ids, mask those tool-result tokens out of the loss (tool_mask=0), truncate consistently to max_completion_length (including logprobs alignment), and remove those samples from further post-tool generate turns while letting other samples continue.

Introduces a comprehensive test suite (tests/test_grpo_trainer_stop_tool.py) that mocks model.generate to assert the reduced generation call count, mixed-batch behavior, reward-visible tool messages, finite loss, and truncation/mask/logprobs alignment, with xfail on older transformers without tool parsing.

Written by Cursor Bugbot for commit 4349c67.

  Adds GRPOConfig.stop_tool_names: list[str] | None. When the model calls
  a tool whose name is in this list, that sample exits the tool-calling
  loop immediately after the tool result is appended — no further
  generation occurs for it. Other samples in the batch are unaffected.

  Fixes huggingface#5389
Copilot AI review requested due to automatic review settings March 27, 2026 19:31
Contributor

Copilot AI left a comment


Pull request overview

Adds support in GRPO’s agent/tool execution loop to treat certain tools as “terminal” (e.g., final_answer), so that when they’re called the sample stops generating further turns immediately—aligning training-time behavior with eval-time agent frameworks.

Changes:

  • Added stop_tool_names: list[str] | None to GRPOConfig.
  • Updated GRPOTrainer._tool_call_loop to detect stop tools per-sample and remove those samples from further generation after appending the tool result.
  • Added a dedicated test suite covering stop-tool termination behavior and no-op behavior when disabled.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
trl/trainer/grpo_trainer.py Implements stop-tool detection and per-sample early termination in the tool loop.
trl/trainer/grpo_config.py Adds and documents the new stop_tool_names config field.
tests/test_grpo_trainer_stop_tool.py Adds unit/integration tests for stop-tool termination semantics.


Comment on lines +1544 to +1549
# Extend completion_ids to include tool-result tokens.
completion_ids[idx_with_tool] = pct[prompt_length:]
# Mask tool-result tokens out of the loss.
tool_mask[idx_with_tool] += [0] * tool_suffix_length
if logprobs is not None:
logprobs[idx_with_tool] += [0.0] * tool_suffix_length

Copilot AI Mar 27, 2026


Stop-tool samples update completion_ids to include the tool suffix, but this path doesn’t re-apply the max_completion_length truncation logic used later in the normal tool-flow. If a tool result (or its chat-template wrapper) is long, completion_ids / tool_mask / logprobs can exceed max_completion_length, which can increase memory/compute and break the expectation that completions are length-bounded. Consider truncating pct[prompt_length:] to self.max_completion_length here and trimming the corresponding mask/logprob extensions accordingly (mirroring the truncation logic used for the post-tool generation path).

Suggested change
# Extend completion_ids to include tool-result tokens.
completion_ids[idx_with_tool] = pct[prompt_length:]
# Mask tool-result tokens out of the loss.
tool_mask[idx_with_tool] += [0] * tool_suffix_length
if logprobs is not None:
logprobs[idx_with_tool] += [0.0] * tool_suffix_length
# Build the full completion (completion + tool suffix) and enforce max_completion_length.
full_completion = pct[prompt_length:]
final_length = min(len(full_completion), self.max_completion_length)
# Extend and truncate completion_ids to include (possibly truncated) tool-result tokens.
completion_ids[idx_with_tool] = full_completion[:final_length]
# Mask tool-result tokens out of the loss, keeping alignment with completion_ids.
extended_tool_mask = tool_mask[idx_with_tool] + [0] * tool_suffix_length
tool_mask[idx_with_tool] = extended_tool_mask[:final_length]
if logprobs is not None:
extended_logprobs = logprobs[idx_with_tool] + [0.0] * tool_suffix_length
logprobs[idx_with_tool] = extended_logprobs[:final_length]

Comment on lines +1550 to +1555
idxs_with_tool = [i for i, keep in zip(idxs_with_tool, non_stop_mask, strict=False) if keep]
prompt_completion_tools = [
p for p, keep in zip(prompt_completion_tools, non_stop_mask, strict=False) if keep
]
prompt_completion_tool_ids = [
p for p, keep in zip(prompt_completion_tool_ids, non_stop_mask, strict=False) if keep

Copilot AI Mar 27, 2026


The filtering zips here use strict=False, but non_stop_mask is derived directly from idxs_with_tool, so the lengths should always match. Using strict=True (as done in nearby zips in this function) would fail fast if an invariant is broken instead of silently truncating, which makes debugging easier.

Suggested change
idxs_with_tool = [i for i, keep in zip(idxs_with_tool, non_stop_mask, strict=False) if keep]
prompt_completion_tools = [
p for p, keep in zip(prompt_completion_tools, non_stop_mask, strict=False) if keep
]
prompt_completion_tool_ids = [
p for p, keep in zip(prompt_completion_tool_ids, non_stop_mask, strict=False) if keep
idxs_with_tool = [i for i, keep in zip(idxs_with_tool, non_stop_mask, strict=True) if keep]
prompt_completion_tools = [
p for p, keep in zip(prompt_completion_tools, non_stop_mask, strict=True) if keep
]
prompt_completion_tool_ids = [
p for p, keep in zip(prompt_completion_tool_ids, non_stop_mask, strict=True) if keep


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.


@qgallouedec
Member

@codex can this be simplified?

@chatgpt-codex-connector

Codex couldn't complete this request. Try again later.

The stop-tool code path appended tool-result suffix tokens to
completion_ids without enforcing max_completion_length, causing the
padded completion tensor to exceed the expected width and silently
increasing memory and logits_to_keep compute. All other paths (regular
post-tool at L1571-1585, overlong at L1517) already enforced this limit.

Truncate full_completion, tool_mask, and logprobs to
min(len, max_completion_length) to match the regular post-tool path.

Also fix strict=False → strict=True in the three zip() calls filtering
by non_stop_mask, consistent with every other zip(..., strict=True) in
the file.

Add two regression tests:
- test_stop_tool_enforces_max_completion_length
- test_stop_tool_truncation_preserves_mask_and_logprobs_alignment
@qgallouedec
Member

@albertvillanova I don't know how common this stop_tool feature is. Do you know whether it's something specific to smolagents or more widely used? If so, I'd be happy to add this feature

@lailanelkoussy
Contributor Author

lailanelkoussy commented Mar 27, 2026

@qgallouedec I was also considering proposing a PR to support varying the available tools per rollout (i.e., per datapoint / sampled trajectory used for training). I’m currently fine-tuning a model with tools and validating checkpoints using smolagents, and I think it would benefit from exposing the model to more diverse toolsets across rollouts.
If you think this could be combined with this PR in a clean way, I’d be happy to work on it.

@qgallouedec
Member

Yes, I think we should find a way to make the tools vary. I'm not sure how yet, maybe with a column in the dataset? But in any case this is a separate concern and should not be combined with this PR

@lailanelkoussy
Contributor Author

Makes sense, thanks! I’ll open a separate PR for this.
My current idea is to add a tools column to the dataset, where each datapoint specifies the subset of tools (by name) available for that rollout, while keeping GRPOTrainer(tools=[...]) as the global pool and filtering it per rollout.
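That filtering idea could look roughly like this (a hypothetical sketch of a follow-up PR, not code from this one; `select_tools_for_rollout` and the column format are assumptions):

```python
def select_tools_for_rollout(global_tools, allowed_names):
    """Filter the trainer's global tool pool by per-datapoint tool names.

    global_tools: the functions passed to GRPOTrainer(tools=[...]).
    allowed_names: the "tools" column value for one datapoint.
    """
    by_name = {tool.__name__: tool for tool in global_tools}
    # Preserve the order given in the dataset column; skip unknown names.
    return [by_name[name] for name in allowed_names if name in by_name]
```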

@qgallouedec
Member

yep, that would be consistent with SFT/DPO etc, ok please let's continue this discussion in another issue or pr 👍



Development

Successfully merging this pull request may close these issues.

# [Feature Request] GRPOTrainer: support stop-tool termination in agent training loop
