
[BUG] Partial rollout groups silently corrupt GRPO group normalization #1419

Description

@EazyReal

Checklist

  • The error occurs when using our provided Docker image. — N/A: framework-level numerical logic, reproduced on CPU without Docker (snippet below).
  • I can consistently reproduce the bug across multiple trials or random seeds. — deterministic.
  • If the error causes experiment abortion, I've verified that this error is the root cause, not a secondary error caused by peer workers. — N/A: this is a silent correctness bug (no crash/abortion); root cause confirmed in Normalization.

Detailed Information

Describe the bug

What's affected. Group-level reward/advantage normalization — Normalization with mean_level/std_level = "group", i.e. the GRPO/RLOO baseline used by default configs such as examples/math/gsm8k_grpo.yaml:

reward_norm: { mean_level: group, std_level: group, group_size: ${gconfig.n_samples} }
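For reference, group-level normalization scores each row against its own prompt group. Here it is written out for a group g with row set G_g. This assumes the common form in which eps is added to the standard deviation, which matches the divide-by-eps failure described below:

```latex
\mu_g = \frac{1}{|G_g|}\sum_{j \in G_g} r_j, \qquad
\sigma_g = \operatorname{std}\{\, r_j : j \in G_g \,\}, \qquad
\hat{A}_i = \frac{r_i - \mu_g}{\sigma_g + \varepsilon} \quad (i \in G_g)
```

The rest of this issue concerns which rows end up in G_g.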

Context (how partial groups arise). GroupedRolloutWorkflow runs group_size episodes per prompt; when some return None it keeps the survivors and logs "...using remaining results" — producing a partial group with fewer than group_size rows.

Root cause. Normalization recovers groups by a fixed group_size stride, assuming every group has exactly group_size rows:

for i in range(bs // group_size):
    s = slice(i * group_size, (i + 1) * group_size)   # assumes each group is exactly group_size rows

Once any group is partial, this silently mis-groups the rows:

  • Trailing partial group: its rows fall past the last full stride and are never normalized. Their std stays 0, so they are divided by eps alone and the advantages explode (~1e5 / NaN).
  • Mid-batch partial group: it shifts every later boundary, so subsequent groups mix rows from different prompts. This happens even when the batch size is still divisible by group_size. For example, group sizes [4, 4, 3, 1] sum to 12 (divisible by 4), yet only the first two groups are aligned. A contaminated baseline can flip the sign of an advantage, which pushes the policy update in the wrong direction.
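The following dependency-free imitation of the fixed-stride loop shows where the ~1e5 comes from in the trailing case. It is not areal's code. positional_group_norm and EPS = 1e-5 are hypothetical. The sketch assumes the per-row mean/std buffers start at zero, so any row the loop never visits is divided by eps alone.

```python
import statistics

EPS = 1e-5  # hypothetical epsilon; areal's actual value may differ


def positional_group_norm(x, group_size, eps=EPS):
    """Imitate fixed-stride group normalization (illustration only).

    Assumption: per-row mean/std buffers start at zero, so rows past the
    last full stride keep mean = 0 and std = 0.
    """
    bs = len(x)
    mean = [0.0] * bs
    std = [0.0] * bs
    for i in range(bs // group_size):
        rows = range(i * group_size, (i + 1) * group_size)
        vals = [x[r] for r in rows]
        m = statistics.mean(vals)
        s = statistics.stdev(vals)  # sample (n - 1) standard deviation
        for r in rows:
            mean[r], std[r] = m, s
    return [(v - m) / (s + eps) for v, m, s in zip(x, mean, std)]


# Trailing partial group: sizes [4, 4, 3], bs = 11.
x = [0.5, -0.5, 1.0, -1.0, 2.0, -2.0, 0.0, 3.0, 0.0, 1.0, 2.0]
out = positional_group_norm(x, group_size=4)
print(out[8:])  # rows 8-10 were never normalized: roughly [0.0, 1e5, 2e5]
```

The first eight rows come out as ordinary O(1) advantages. The last three are raw rewards scaled by 1/eps.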

So a bs % group_size == 0 check is necessary but not sufficient — the real invariant is that each group's actual rows are normalized together.
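Prompt labels alone are enough to show the gap between the divisibility check and the real invariant. In this sketch, positional_slices, actual_groups and labels are hypothetical helpers. positional_slices reproduces the row ranges visited by the fixed-stride loop quoted above.

```python
from itertools import accumulate


def positional_slices(bs, group_size):
    """Row ranges visited by `for i in range(bs // group_size)` with a fixed stride."""
    return [range(i * group_size, (i + 1) * group_size) for i in range(bs // group_size)]


def actual_groups(group_sizes):
    """Row ranges of the real prompt groups, derived from their sizes."""
    ends = list(accumulate(group_sizes))
    starts = [0] + ends[:-1]
    return [range(s, e) for s, e in zip(starts, ends)]


def labels(group_sizes):
    """One prompt label per row: 'A' for the first group, 'B' for the next, ..."""
    return [chr(ord("A") + g) for g, n in enumerate(group_sizes) for _ in range(n)]


sizes = [4, 4, 3, 1]
bs, group_size = sum(sizes), 4
row_labels = labels(sizes)

print(bs % group_size == 0)  # True: the divisibility check passes
print(["".join(row_labels[r] for r in s) for s in positional_slices(bs, group_size)])
# ['AAAA', 'BBBB', 'CCCD']: the third slice mixes prompts C and D
print(positional_slices(bs, group_size) == actual_groups(sizes))  # False
```

With full groups ([4, 4, 4]) the two lists of ranges coincide. The bug therefore only appears once some group comes back partial.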

Expected behavior

Each prompt's group is normalized within its own rows, regardless of partial/unequal groups; full (non-partial) groups stay bit-identical.
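The expected behaviour can be sketched without areal. This is a minimal, dependency-free illustration, not the code of #1415. group_norm and EPS are hypothetical names, and the std is the sample (n − 1) form. Single-row groups get std = 0 because the issue does not say how singletons are handled. Otherwise it follows the fix as described under Resolution / scope: an optional group_sizes argument buckets rows by their actual counts, and the positional fallback raises on a non-divisible batch.

```python
import statistics
from itertools import accumulate

EPS = 1e-5  # hypothetical epsilon for this sketch


def group_norm(x, group_size, group_sizes=None, eps=EPS):
    """Normalize each prompt group within its own rows (illustrative sketch).

    group_sizes: optional positive per-group row counts, in batch order.
    Without it, fall back to fixed-stride groups and refuse a batch that
    does not divide evenly instead of mis-slicing it.
    """
    bs = len(x)
    if group_sizes is None:
        if bs % group_size != 0:
            raise ValueError(f"batch size {bs} is not divisible by group_size {group_size}")
        group_sizes = [group_size] * (bs // group_size)
    if sum(group_sizes) != bs:
        raise ValueError(f"group_sizes sum to {sum(group_sizes)}, expected {bs}")

    out = []
    start = 0
    for end in accumulate(group_sizes):
        vals = x[start:end]
        m = statistics.mean(vals)
        # Assumption: a single-row group gets std = 0 (its advantage becomes 0).
        s = statistics.stdev(vals) if len(vals) > 1 else 0.0
        out.extend((v - m) / (s + eps) for v in vals)
        start = end
    return out


x = [0.5, -0.5, 1.0, -1.0, 2.0, -2.0, 0.0, 3.0, 0.0, 1.0, 2.0, 100.0]
print(round(group_norm(x, group_size=4)[10], 3))  # -0.48: positional slice mixes in prompt D
fixed = group_norm(x, group_size=4, group_sizes=[4, 4, 3, 1])
print(round(fixed[10], 3))  # 1.0: row 10 is judged against prompt C's own mean
```

Full groups go through the same per-group reduction whether or not group_sizes is passed, so their output does not change. This mirrors the bit-identical claim for full groups.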

Before / after (minimal, CPU-only, no training run)

import torch
from areal.api.cli_args import NormConfig
from areal.utils.data import Normalization

norm = Normalization(NormConfig(mean_level="group", std_level="group", group_size=4))

# groups [4, 4, 3, 1] (bs=12). Rows 8-10 = prompt C = [0,1,2]; row 11 = prompt D = 100 (outlier).
x = torch.tensor([0.5, -0.5, 1., -1.,  2., -2., 0., 3.,  0., 1., 2.,  100.])

norm(x)[10]                          # BEFORE: -0.480  (row 10's baseline mixed with prompt D=100 -> sign FLIPPED)
norm(x, group_sizes=[4, 4, 3, 1])[10]  # AFTER:  +1.000  (baseline = mean(prompt C) = 1.0 -> correct)

Row 10 (value 2, which is above its own prompt's mean of 1) should have a positive advantage. Positional slicing lumps prompt D's outlier into the baseline and flips it to -0.480. The trailing case ([4, 4, 3], bs=11) instead leaves the last 3 rows un-normalized → ~1e5-scale advantages.
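Both quoted values for row 10 can be checked by hand. This arithmetic treats eps as negligible and uses the sample (n − 1) standard deviation, which is the convention the quoted numbers are consistent with. A population standard deviation would give about −0.554 before and 1.225 after.

```latex
\begin{aligned}
\text{Before, rows 8--11} = \{0, 1, 2, 100\}:\quad
  & \mu = \tfrac{0 + 1 + 2 + 100}{4} = 25.75, \quad
    \sigma = \sqrt{\tfrac{25.75^2 + 24.75^2 + 23.75^2 + 74.25^2}{3}} = \sqrt{\tfrac{7352.75}{3}} \approx 49.51, \\
  & \hat{A}_{10} = \tfrac{2 - 25.75}{49.51} \approx -0.480 \\
\text{After, rows 8--10} = \{0, 1, 2\}:\quad
  & \mu = 1, \quad \sigma = \sqrt{\tfrac{1^2 + 0^2 + 1^2}{2}} = 1, \quad
    \hat{A}_{10} = \tfrac{2 - 1}{1} = 1.000
\end{aligned}
```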

To Reproduce

Commit ID

main (still present on current main).

Environment

Framework-level; reproducible on CPU with the snippet above. Affects any config with group-level reward_norm/adv_norm.

Script

The snippet above — no cluster/GPU required.

Resolution / scope

This issue tracks correctly and configurably supporting partial rollout groups. Two parts:

  • Correctness fix: fix(ppo): group-normalize by actual group sizes for partial groups #1415. Normalize by the actual per-group row counts: Normalization gains an optional group_sizes argument and buckets by it; PPOActor.compute_advantages threads the rollout's traj_group_sizes (each trajectory is one prompt group). Without group_sizes the positional path now raises on a non-divisible batch instead of mis-slicing. Reduction math is unchanged → full groups are bit-identical.
  • Optional drop policy: feat(rollout): add min_valid_group_size to drop under-filled rollout groups #1416 (enhancement built on the fix). min_valid_group_size lets a run require a minimum number of survivors and otherwise drop the whole group instead of keeping a partial one (default 1 = current behavior; set to gconfig.n_samples for full groups only).

Both changes are backward compatible: group_sizes is optional, and min_valid_group_size defaults to 1.
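The drop policy comes down to a survivor count per group. The sketch below is minimal. keep_or_drop_group is a hypothetical helper, not the code of #1416, and None marks a failed episode, as in the GroupedRolloutWorkflow description. The sketch also shows how per-group row counts fall out of the kept groups, which is the role traj_group_sizes plays in the fix.

```python
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def keep_or_drop_group(
    results: Sequence[Optional[T]], min_valid_group_size: int = 1
) -> Optional[List[T]]:
    """Keep a prompt's surviving episodes, or drop the whole group (illustrative).

    Failed episodes are None. The survivors are kept as a (possibly partial)
    group only if there are at least `min_valid_group_size` of them.
    """
    survivors = [r for r in results if r is not None]
    if len(survivors) < min_valid_group_size:
        return None  # under-filled: drop the whole group
    return survivors


# Two prompts with 4 episodes each; one episode of the second prompt failed.
rollouts = [["a0", "a1", "a2", "a3"], ["b0", None, "b2", "b3"]]

# Default (1): keep partial groups and record their actual sizes for the normalizer.
kept = [g for g in (keep_or_drop_group(r) for r in rollouts) if g is not None]
group_sizes = [len(g) for g in kept]  # [4, 3]

# Full groups only (min_valid_group_size = n_samples = 4): prompt B is dropped.
strict = [g for g in (keep_or_drop_group(r, min_valid_group_size=4) for r in rollouts) if g is not None]
```

With the default threshold, the partial group survives and must be normalized by its real size. With the threshold equal to the group size, every surviving group is full and fixed-stride slicing stays aligned.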
