Skip to content

New SP support in Trainer#46775

Open
SunMarc wants to merge 4 commits into
mainfrom
sp-update
Open

New SP support in Trainer#46775
SunMarc wants to merge 4 commits into
mainfrom
sp-update

Conversation

@SunMarc

@SunMarc SunMarc commented Jun 19, 2026

Copy link
Copy Markdown
Member

What does this PR do?

This PR adds compatibility with the new native SP support from accelerate. With this, you can perform long context training with FSDPv2, DeepSpeed or even DDP.

Requires huggingface/accelerate#4084

cc @qgallouedec for viz

@SunMarc SunMarc requested a review from winglian June 19, 2026 16:27
@github-actions

Copy link
Copy Markdown
Contributor

CI Dashboard: View test results in Grafana

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec

qgallouedec commented Jun 24, 2026

Copy link
Copy Markdown
Member

I tried with SFT, it's looking good!! A small bug though: it over-scales the LM loss by sp_size

Bug

Trainer._get_num_items_in_batch (transformers/trainer.py ~L2196):

num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size   # = tp_size * cp_size * sp_size

With average_tokens_across_devices=True (default), num_items_in_batch is already gather().sum() over all ranks. Ulysses SP shards the sequence at the dataloader (SequenceShardingDataLoader), so the per-rank counts are disjoint and the gathered sum is the correct global token count. The // sp_size then under-counts the denominator → loss and gradients inflated by sp_size. (Correct for TP, which replicates the batch; wrong for SP, which shards it.)

Fix

divisor = pc.non_data_parallel_size
if pc.sp_enabled and pc.sp_backend == "accelerate":
    divisor //= pc.sp_size   # Ulysses shards the sequence at the dataloader -> gather is already global
num_items_in_batch = num_items_in_batch // divisor

Repro

(dp=2,sp=2) on 4 GPUs vs (dp=2,sp=1) on 2 GPUs: same global batch, same seed → must give the same loss.
At lr=0 (pure forward, isolates the loss math) the sp2 loss is 2.0–2.25× (≈sp_size) the sp1 loss every
step
. After the fix the ratio should be ~1.0.

# sp_loss_repro.py
import json, os
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
sequences = [[0, 1, 2], [10, 11, 12, 13], [20, 21, 22, 23, 24], [30, 31, 32]] * 64
dataset = Dataset.from_dict({"input_ids": sequences})
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False, pad_to_multiple_of=2)  # seq len % sp_size == 0

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="sp-loss-repro", per_device_train_batch_size=2, max_steps=30, logging_steps=1,
        learning_rate=0.0, seed=42, report_to="none",
    ),
    train_dataset=dataset, data_collator=collator,
)
trainer.train()
if trainer.accelerator.is_main_process:
    tag = os.environ.get("RUN_TAG", "run")
    curve = [{"step": h["step"], "loss": h["loss"]} for h in trainer.state.log_history if "loss" in h]
    json.dump(curve, open(f"loss_{tag}.json", "w"), indent=2)
# sp_fsdp.yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: 'no'      # fp32 -> clean curve comparison (no bf16 reduction noise)
num_machines: 1
num_processes: 4          # 4 GPUs total
machine_rank: 0
rdzv_backend: static
same_network: true

parallelism_config:
  parallelism_config_dp_shard_size: 2  # FSDP data-parallel shard dim
  parallelism_config_sp_size: 2        # 2 (dp_shard) x 2 (sp) = 4 = num_processes
  parallelism_config_sp_backend: accelerate

fsdp_config:
  fsdp_version: 2
  fsdp_reshard_after_forward: true
RUN_TAG=sp2 accelerate launch --config_file sp_fsdp.yaml  sp_loss_repro.py   # 4 GPUs, dp2/sp2
...
{'loss': '10.58', 'grad_norm': '3536', 'learning_rate': '0', 'epoch': '0.007812'}                         
{'loss': '8.94', 'grad_norm': '1696', 'learning_rate': '0', 'epoch': '0.01562'}                           
{'loss': '11.9', 'grad_norm': '3232', 'learning_rate': '0', 'epoch': '0.02344'}                           
{'loss': '11.84', 'grad_norm': '3456', 'learning_rate': '0', 'epoch': '0.03125'}                          
{'loss': '8.395', 'grad_norm': '1816', 'learning_rate': '0', 'epoch': '0.03906'}
...

# dp2_sp1.yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: 'no'      # fp32 -> clean curve comparison (no bf16 reduction noise)
num_machines: 1
num_processes: 2          # 2 GPUs (sp=1, so dp world is the same 2 as the sp run)
machine_rank: 0
rdzv_backend: static
same_network: true

parallelism_config:
  parallelism_config_dp_shard_size: 2  # 2 (dp_shard) x 1 (sp) = 2 = num_processes
  parallelism_config_sp_size: 1        # SP disabled -> each GPU sees the full sequence

fsdp_config:
  fsdp_version: 2
RUN_TAG=sp1 CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file dp2_sp1.yaml sp_loss_repro.py
...
{'loss': '5.288', 'grad_norm': '1768', 'learning_rate': '0', 'epoch': '0.01562'}                          
{'loss': '4.172', 'grad_norm': '796', 'learning_rate': '0', 'epoch': '0.03125'}                           
{'loss': '5.491', 'grad_norm': '1504', 'learning_rate': '0', 'epoch': '0.04688'}                          
{'loss': '5.918', 'grad_norm': '1728', 'learning_rate': '0', 'epoch': '0.0625'}                           
{'loss': '4.198', 'grad_norm': '908', 'learning_rate': '0', 'epoch': '0.07812'}
...

@qgallouedec

Copy link
Copy Markdown
Member

#46855

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants