
[Feature] Configurable loss aggregation level — token / seq / prompt mean (ScaleRL §3.2) #1423

Description

@EazyReal

Checklist

  • This feature will maintain backward compatibility with the current APIs in areal/api/. (Additive: one new actor.loss_aggregation field; default token_mean is byte-identical to today.)

Background

AReaL hardcodes the policy-gradient loss to a global token mean. ScaleRL (arXiv:2510.13786, §3.2) treats the loss-aggregation level as a tunable axis that determines which unit dominates the gradient, and reports it as a meaningful knob for stability. AReaL does not expose this choice, so users cannot reproduce GRPO-style (per-sequence) or MiniMax-M1-style (per-prompt) aggregation without patching the loss.

Potential Solution

Add actor.loss_aggregation selecting the reduction level:

level                   reduction                   unit weighted equally   source
token_mean (default)    Σ(pg·m)/Σm                  token                   DAPO
seq_mean (new)          mean_i(Σ_t pg·m / |o_i|)    sequence                GRPO
prompt_mean (new)       mean_g(Σ pg·m / Σ m)        prompt-group            MiniMax-M1

The default, token_mean, is byte-identical to today's behavior.
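To make the three reductions concrete, here is a minimal framework-free sketch in plain Python. It is not the code from #1417: the helper names (`_unit_means`, `aggregate_sketch`) are hypothetical, the inputs are nested lists rather than tensors, and it assumes |o_i| is the number of valid (mask = 1) tokens in sequence i and that every unit contains at least one valid token.

```python
from typing import List


def _unit_means(pg: List[List[float]], mask: List[List[int]], unit_size: int) -> List[float]:
    """Masked mean of pg over each unit of `unit_size` consecutive sequences."""
    means = []
    for start in range(0, len(pg), unit_size):
        num, den = 0.0, 0
        for seq_pg, seq_m in zip(pg[start:start + unit_size], mask[start:start + unit_size]):
            num += sum(p * m for p, m in zip(seq_pg, seq_m))
            den += sum(seq_m)
        means.append(num / den)  # assumption: at least one valid token per unit
    return means


def aggregate_sketch(pg, mask, mode: str, group_size: int = 1) -> float:
    """Hypothetical stand-in for the three loss_aggregation levels."""
    if mode == "token_mean":
        unit_size = len(pg)          # one unit: every token in the batch
    elif mode == "seq_mean":
        unit_size = 1                # one unit per sequence
    elif mode == "prompt_mean":
        if len(pg) % group_size:
            raise ValueError("prompt_mean needs whole groups")
        unit_size = group_size       # one unit per prompt group
    else:
        raise ValueError(f"unknown mode: {mode}")
    units = _unit_means(pg, mask, unit_size)
    return sum(units) / len(units)


# Four padded sequences, two prompts with two samples each.
# The 9.0 entries are padding and must be ignored via the mask.
pg = [[1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 9.0, 9.0], [2.0, 9.0, 9.0, 9.0], [4.0, 4.0, 4.0, 9.0]]
mask = [[1, 1, 1, 1], [1, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 0]]

print(aggregate_sketch(pg, mask, "token_mean"))                 # 24 / 10
print(aggregate_sketch(pg, mask, "seq_mean"))                   # (1 + 3 + 2 + 4) / 4
print(aggregate_sketch(pg, mask, "prompt_mean", group_size=2))  # (10/6 + 14/4) / 2
```

On this toy batch the three levels give different values because the sequences have different lengths: token_mean favors long sequences, seq_mean weights every sequence equally, and prompt_mean weights every prompt group equally.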

  • One seam, no new machinery. seq_mean and prompt_mean share a single per-unit reduction in aggregate_pg_loss (unit = one sequence, or group_size consecutive sequences). Paired with a _make_loss_weight_fn that returns the unit count, AReaL's existing engine contract Σ(loss_mb·w_mb)/Σw_mb realizes the exact global mean over that unit — each mode is just a (reduction, weight) pair, with no cross-microbatch denominators.
  • Single source of truth. For prompt_mean the prompt-group size is gconfig.n_samples, so group_size is derived rather than exposed as a separate knob. mb_spec.granularity is auto-bumped to a multiple of it, and rollout.min_valid_group_size is raised to it, because prompt_mean groups positionally and needs whole groups. The only user knob is actor.loss_aggregation.
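The (reduction, weight) pairing described above can be checked on a toy example. The sketch below is an illustration under stated assumptions, not AReaL's engine code: `microbatch_loss`, `microbatch_weight`, and `engine_reduce` are hypothetical names, each sequence is summarized as a (masked loss sum, valid-token count) pair, and microbatches are assumed to contain only whole groups.

```python
from typing import List, Tuple

Seq = Tuple[float, int]  # (sum of masked pg loss, valid-token count) for one sequence


def _unit_size(mode: str, group_size: int) -> int:
    return 1 if mode == "seq_mean" else group_size


def microbatch_loss(mb: List[Seq], mode: str, group_size: int) -> float:
    """Per-microbatch reduction: mean over units inside this microbatch."""
    if mode == "token_mean":
        return sum(s for s, _ in mb) / sum(n for _, n in mb)
    unit = _unit_size(mode, group_size)
    units = [mb[i:i + unit] for i in range(0, len(mb), unit)]
    per_unit = [sum(s for s, _ in u) / sum(n for _, n in u) for u in units]
    return sum(per_unit) / len(per_unit)


def microbatch_weight(mb: List[Seq], mode: str, group_size: int) -> int:
    """Weight = number of units (tokens, sequences, or groups) in this microbatch."""
    if mode == "token_mean":
        return sum(n for _, n in mb)
    return len(mb) // _unit_size(mode, group_size)


def engine_reduce(microbatches: List[List[Seq]], mode: str, group_size: int) -> float:
    """Σ(loss_mb · w_mb) / Σ w_mb over microbatches."""
    pairs = [(microbatch_loss(mb, mode, group_size),
              microbatch_weight(mb, mode, group_size)) for mb in microbatches]
    return sum(l * w for l, w in pairs) / sum(w for _, w in pairs)


# Six sequences = three prompt groups of two samples each.
batch = [(4.0, 4), (6.0, 2), (2.0, 1), (12.0, 3), (5.0, 5), (3.0, 1)]
split = [batch[:2], batch[2:]]  # uneven microbatches: one group, then two groups

for mode in ("token_mean", "seq_mean", "prompt_mean"):
    whole = engine_reduce([batch], mode, group_size=2)
    chunked = engine_reduce(split, mode, group_size=2)
    print(mode, whole, chunked)  # the two values agree up to float rounding
```

Because each microbatch reports its own unit count as the weight, the weighted combination recovers the global mean over units even when microbatches hold different numbers of tokens, sequences, or groups. This is why no denominator has to be shared across microbatches.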

Additional Information

Implemented in #1417 (draft). Because prompt_mean groups positionally, it depends on the partial-group reward/advantage fix in #1416 and is stacked on it. Mutation-verified tests are in tests/test_prompt_mean_loss.py, covering per-mode values, packed==padded equivalence, and a three-mode pairing invariant tying loss_fn/loss_weight_fn to the engine reduction.
