Checklist
Detailed Information
Describe the bug
What's affected. Group-level reward/advantage normalization — Normalization with mean_level/std_level = "group", i.e. the GRPO/RLOO baseline used by default configs such as examples/math/gsm8k_grpo.yaml:
reward_norm: { mean_level: group, std_level: group, group_size: ${gconfig.n_samples} }
Context (how partial groups arise). GroupedRolloutWorkflow runs group_size episodes per prompt; when some return None it keeps the survivors and logs "...using remaining results" — producing a partial group with fewer than group_size rows.
Root cause. Normalization recovers groups by a fixed group_size stride, assuming every group has exactly group_size rows:
for i in range(bs // group_size):
    s = slice(i * group_size, (i + 1) * group_size)  # assumes each group is exactly group_size rows
Once any group is partial, this silently mis-groups the rows:
- Trailing partial group: its rows fall past the last full stride, are never normalized, keep std = 0, and get divided by eps, so advantages explode (~1e5 / NaN).
- Mid-batch partial group: it shifts every later boundary, so subsequent groups mix rows from different prompts. This happens even when the batch size is still divisible by group_size: e.g. group sizes [4, 4, 3, 1] sum to 12 (divisible by 4), yet only the first two groups are aligned. A contaminated baseline can flip the sign of an advantage, updating the policy in the wrong direction.
So a bs % group_size == 0 check is necessary but not sufficient: the real invariant is that each group's actual rows are normalized together.
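The boundary mismatch described above can be seen without any tensors. The following is a minimal sketch in plain Python; stride_slices and actual_slices are hypothetical helper names used only for illustration and are not part of the library.

```python
from itertools import accumulate


def stride_slices(bs, group_size):
    """(start, end) boundaries implied by a fixed group_size stride."""
    return [(i * group_size, (i + 1) * group_size) for i in range(bs // group_size)]


def actual_slices(group_sizes):
    """(start, end) boundaries implied by the real per-prompt row counts."""
    ends = list(accumulate(group_sizes))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


# Mid-batch partial group: bs = 12 is divisible by 4, yet the last boundary disagrees.
print(stride_slices(12, 4))         # [(0, 4), (4, 8), (8, 12)]
print(actual_slices([4, 4, 3, 1]))  # [(0, 4), (4, 8), (8, 11), (11, 12)]

# Trailing partial group: rows 8-10 are covered by no stride at all.
print(stride_slices(11, 4))         # [(0, 4), (4, 8)]
print(actual_slices([4, 4, 3]))     # [(0, 4), (4, 8), (8, 11)]
```

In the first case the stride's third slice (8, 12) spans two prompts; in the second case the last three rows are simply skipped.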
Expected behavior
Each prompt's group is normalized within its own rows, regardless of partial/unequal groups; full (non-partial) groups stay bit-identical.
Before / after (minimal, CPU-only, no training run)
import torch

from areal.api.cli_args import NormConfig
from areal.utils.data import Normalization

norm = Normalization(NormConfig(mean_level="group", std_level="group", group_size=4))

# Groups [4, 4, 3, 1] (bs=12). Rows 8-10 = prompt C = [0, 1, 2]; row 11 = prompt D = 100 (outlier).
x = torch.tensor([0.5, -0.5, 1., -1., 2., -2., 0., 3., 0., 1., 2., 100.])

# BEFORE the fix: -0.480 (row 10's baseline is mixed with prompt D = 100, so the sign is FLIPPED).
print(norm(x)[10])
# AFTER the fix: +1.000 (baseline = mean(prompt C) = 1.0, which is correct).
print(norm(x, group_sizes=[4, 4, 3, 1])[10])
Row 10 (value 2, which is above its own prompt's mean of 1) should have a positive advantage. Positional slicing lumps prompt D's outlier into the baseline and flips it to -0.480. The trailing case ([4, 4, 3], bs=11) instead leaves the last 3 rows un-normalized → ~1e5-scale advantages.
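The two numbers can also be checked by hand, independently of the library. The sketch below is a plain-Python re-derivation, not the Normalization implementation: it assumes the sample (unbiased) standard deviation, an illustrative eps of 1e-5, and a zero spread for single-row groups. The function names are hypothetical, and rows past the last full stride are simply passed through rather than modelling the library's eps division.

```python
from statistics import mean, stdev

EPS = 1e-5  # assumed value, for illustration only


def normalize_group(rows):
    """Subtract the group mean and divide by the group's sample std."""
    mu = mean(rows)
    sigma = stdev(rows) if len(rows) > 1 else 0.0  # a single row has no spread
    return [(r - mu) / (sigma + EPS) for r in rows]


def normalize_by_stride(x, group_size):
    """Positional grouping: every group is assumed to have group_size rows."""
    out = list(x)  # rows past the last full stride are left untouched
    for i in range(len(x) // group_size):
        s = slice(i * group_size, (i + 1) * group_size)
        out[s] = normalize_group(x[s])
    return out


def normalize_by_sizes(x, group_sizes):
    """Explicit grouping: each group covers exactly its own rows."""
    if sum(group_sizes) != len(x):
        raise ValueError("group_sizes must sum to the number of rows")
    out, start = [], 0
    for n in group_sizes:
        out.extend(normalize_group(x[start:start + n]))
        start += n
    return out


x = [0.5, -0.5, 1.0, -1.0, 2.0, -2.0, 0.0, 3.0, 0.0, 1.0, 2.0, 100.0]
print(round(normalize_by_stride(x, 4)[10], 3))            # -0.48
print(round(normalize_by_sizes(x, [4, 4, 3, 1])[10], 3))  # 1.0
```

The first eight rows (the two full groups) come out identical under both functions, which is the "full groups stay bit-identical" property the expected behavior asks for.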
To Reproduce
Commit ID
main (still present on current main).
Environment
Framework-level; reproducible on CPU with the snippet above. Affects any config with group-level reward_norm/adv_norm.
Script
The snippet above — no cluster/GPU required.
Resolution / scope
This issue tracks correctly and configurably supporting partial rollout groups. Two parts:
- Normalization gains an optional group_sizes argument and buckets by it; PPOActor.compute_advantages threads the rollout's traj_group_sizes (each trajectory is one prompt group). Without group_sizes the positional path now raises on a non-divisible batch instead of mis-slicing. Reduction math is unchanged, so full groups are bit-identical.
- min_valid_group_size lets a run require a minimum number of survivors and otherwise drop the whole group instead of keeping a partial one (default 1 = current behavior; set to gconfig.n_samples for full groups only).
Both changes are backward compatible: group_sizes is optional, and min_valid_group_size defaults to 1.
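The min_valid_group_size policy can be illustrated with a short sketch. This is one possible reading of the behavior described above, not the workflow's actual code: collect_groups is a hypothetical helper, and each episode is represented by a bare reward (or None for a failed episode).

```python
def collect_groups(groups, min_valid_group_size=1):
    """Flatten per-prompt episode results into rows plus per-group row counts.

    Failed episodes are None. A group with fewer than min_valid_group_size
    survivors is dropped entirely rather than kept as a partial group.
    """
    rows, group_sizes = [], []
    for episodes in groups:
        survivors = [e for e in episodes if e is not None]
        if not survivors or len(survivors) < min_valid_group_size:
            continue  # drop the whole group
        rows.extend(survivors)
        group_sizes.append(len(survivors))
    return rows, group_sizes


groups = [
    [1.0, 0.0, 1.0, 0.0],     # full group
    [1.0, None, 0.0, 1.0],    # one failed episode
    [None, None, None, 0.5],  # three failed episodes
]

print(collect_groups(groups)[1])                          # [4, 3, 1]
print(collect_groups(groups, min_valid_group_size=2)[1])  # [4, 3]
print(collect_groups(groups, min_valid_group_size=4)[1])  # [4]
```

The returned group_sizes list is the kind of per-group count that a group-aware normalizer needs in order to bucket rows correctly when partial groups are kept.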