39 commits
All 39 commits are authored by hcsolakoglu:

- `4c4a7b0` Add probabilistic pretrain and GRPO RL training, keep backward compat… (Jan 11, 2026)
- `6a6dac1` Add probabilistic pretrain and GRPO RL training, keep backward compat… (Jan 11, 2026)
- `7db2ef9` Warn on missing gaussian ln_sig head during soft load (Jan 11, 2026)
- `63fa593` Update WeSpeaker fetch script to use HF archives (Jan 11, 2026)
- `e6c0e80` Add 8-bit optimizer support for GRPO and ignore checkpoints (Jan 11, 2026)
- `37c3860` Document RL stages and improve GRPO logging (Jan 11, 2026)
- `9d10435` Log GRPO metrics on main process (Jan 11, 2026)
- `2ccf797` Harden RL deps and training setup (Jan 11, 2026)
- `748b80a` Add tests for trainer and wespeaker guardrails (Jan 11, 2026)
- `6ca8ccd` Add opt-in per-sample prompt length for GRPO (Jan 11, 2026)
- `f40d3e5` Keep ref model eval in GRPO forward_rl (Jan 11, 2026)
- `f16681a` Fix RL resume and pin RL deps (Jan 11, 2026)
- `f0636e3` Add reward provider device tests (Jan 11, 2026)
- `3dd00f5` Improve GRPO logging config and cadence (Jan 11, 2026)
- `0681b27` Default test audio pack to HF dummy dataset (Jan 11, 2026)
- `439ec1a` Document RL smoke test workflow (Jan 11, 2026)
- `a6241d4` Add trackio logging option (Jan 11, 2026)
- `6a73693` Clarify reward metric names in logs (Jan 11, 2026)
- `732c392` Add reward correctness tests (Jan 11, 2026)
- `240ed9f` Document longer GPU run and better dataset (Jan 11, 2026)
- `6172de3` Add colab RL pipeline and char-level WER option (Jan 11, 2026)
- `1360365` Add FunASR ref_source option for audio-based WER (Jan 11, 2026)
- `ac3a777` Document colab RL run notes and improve wandb logging fallback (Jan 11, 2026)
- `9005100` Stabilize ruff import sorting for wandb (Jan 12, 2026)
- `3926e1b` Add opt-in RL prompt bounds and steps+1 (Jan 12, 2026)
- `f4a7956` Align reward defaults and add clarity comments (Jan 12, 2026)
- `0c33d73` Document latest RL branch changes and config defaults (Jan 12, 2026)
- `c4a5104` Fix RL sample logging and config wiring (Jan 12, 2026)
- `bd3a3ce` Fix range prompt handling in GRPO (Jan 12, 2026)
- `f7edea1` Document recommended RL opt-ins (Jan 12, 2026)
- `9e21ab7` Add GRPO stability opt-ins and sampler generator (Jan 12, 2026)
- `9d5c4c7` Document sampler memory optimization (Jan 12, 2026)
- `9405b53` Add opt-in KL alignment and strict no-ref audio (Jan 12, 2026)
- `db9937e` feat(rl): add opt-in legacy length check and max duration config (Jan 12, 2026)
- `9511f b1` Fix GRPO accumulation, skip-grad, and reward ref (Jan 12, 2026)
- `89ae389` Add tests for GRPO skip-grad and refs (Jan 12, 2026)
- `aebeb10` Init gaussian ln_sig head and add perf test (Jan 12, 2026)
- `12aa12e` Add probabilistic pretrain and GRPO RL training, keep backward compat… (Jan 12, 2026)
- `3951a10` Apply ruff formatting (Jan 12, 2026)
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,11 +1,14 @@
# Customed
.vscode/
tests/
runs/
data/
ckpts/
checkpoints/
wandb/
.wandb/
.pre-commit-cache/
results/
tests/assets/audio_pack/

# Byte-compiled / optimized / DLL files
__pycache__/
274 changes: 274 additions & 0 deletions README_RL.md
@@ -0,0 +1,274 @@
# RL Two-Stage Training (Warmup + GRPO)

This guide covers the RL workflow for F5-TTS:
- Stage 1: Gaussian NLL warmup (probabilistic head pretrain)
- Stage 2: GRPO fine-tune with reward models

It includes dataset layout, reward model setup, and minimal launch commands.

## Requirements

- Use your uv venv (examples use `./.venv/bin/python`; adjust if your venv lives elsewhere).
- Install RL extras:
```bash
./.venv/bin/python -m pip install -e ".[rl]"
```
This includes `huggingface_hub` for the fetch scripts.
- Optional trackers:
```bash
./.venv/bin/python -m pip install -e ".[trackio]"
```
- Reward models:
- FunASR SenseVoiceSmall
- WeSpeaker cnceleb_resnet34 (fbank frontend)

## Dataset layout

`train.py` uses `data/<dataset_name>_<tokenizer>` for `CustomDataset`:
- `data/<dataset>_<tokenizer>/raw` (HF dataset saved via `save_to_disk`)
- `data/<dataset>_<tokenizer>/duration.json`
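
For the `mini_rl` + `custom` example used below, the expected on-disk layout looks like this (the `wavs/` directory is where the preparation scripts write source audio):

```text
data/mini_rl_custom/
├── raw/             # HF dataset saved via save_to_disk
├── duration.json    # {"duration": [12.3, 4.5, ...]}
└── wavs/            # audio files referenced by each row's audio_path
```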

### Option A: quick dummy dataset (HF internal)

```bash
./.venv/bin/python - <<'PY'
from datasets import load_dataset, Dataset
from pathlib import Path
import soundfile as sf
import json

out_root = Path("data/mini_rl_custom")
wav_dir = out_root / "wavs"
wav_dir.mkdir(parents=True, exist_ok=True)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", split="validation")
text_key = "text" if "text" in ds.column_names else ds.column_names[-1]

items = []
durations = []
for idx, row in enumerate(ds):
    if idx >= 8:
        break
    audio = row["audio"]["array"]
    sr = row["audio"]["sampling_rate"]
    text = row[text_key]
    wav_path = wav_dir / f"sample_{idx:02d}.wav"
    sf.write(wav_path, audio, sr)
    duration = len(audio) / sr
    items.append({"audio_path": str(wav_path), "text": text, "duration": duration})
    durations.append(duration)

Dataset.from_list(items).save_to_disk(str(out_root / "raw"))
with (out_root / "duration.json").open("w", encoding="utf-8") as f:
    json.dump({"duration": durations}, f)
PY
```

Use `datasets.name=mini_rl` with `model.tokenizer=custom` so the loader reads
`data/mini_rl_custom`.

### Option B: small LibriSpeech subset (streaming)

```bash
./.venv/bin/python - <<'PY'
from datasets import load_dataset, Dataset
from pathlib import Path
import json
import soundfile as sf

out_root = Path("data/mini_rl_custom")
wav_dir = out_root / "wavs"
wav_dir.mkdir(parents=True, exist_ok=True)

ds = load_dataset("librispeech_asr", "clean", split="train.100", streaming=True)
items = []
durations = []
for idx, row in enumerate(ds):
    if idx >= 64:
        break
    audio = row["audio"]["array"]
    sr = row["audio"]["sampling_rate"]
    text = row.get("text", "")
    wav_path = wav_dir / f"sample_{idx:02d}.wav"
    sf.write(wav_path, audio, sr)
    duration = len(audio) / sr
    items.append({"audio_path": str(wav_path), "text": text, "duration": duration})
    durations.append(duration)

Dataset.from_list(items).save_to_disk(str(out_root / "raw"))
with (out_root / "duration.json").open("w", encoding="utf-8") as f:
    json.dump({"duration": durations}, f)
PY
```

## Reward model assets

FunASR:
```bash
./.venv/bin/python -m f5_tts.scripts.fetch_reward_asr_model \
--cache_dir checkpoints/funasr/SenseVoiceSmall
```

WeSpeaker (HF archive; fbank frontend):
```bash
./.venv/bin/python -m f5_tts.scripts.fetch_reward_spk_model \
--cache_dir checkpoints/wespeaker/cnceleb_resnet34
```

WeSpeaker `model_dir` should point to:
```
checkpoints/wespeaker/cnceleb_resnet34/cnceleb_resnet34
```
Alternatively, set `WESPEAKER_HOME` to the model folder and omit `model_dir` in the config.
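
For the `WESPEAKER_HOME` route, a minimal example (this path assumes the fetch script above used the default `--cache_dir`):

```bash
# Assumes the default cache_dir from the fetch command above.
export WESPEAKER_HOME="$PWD/checkpoints/wespeaker/cnceleb_resnet34/cnceleb_resnet34"
```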

## Stage 1: Warmup (gaussian_nll)

Download a pretrained base checkpoint and place it in the warmup dir:

```bash
./.venv/bin/python - <<'PY'
from cached_path import cached_path
from pathlib import Path
import shutil

ckpt = cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")
out_dir = Path("ckpts/mini_rl_warmup")
out_dir.mkdir(parents=True, exist_ok=True)
dst = out_dir / ("pretrained_" + Path(ckpt).name)
if not dst.exists():
    shutil.copy2(ckpt, dst)
print(dst)
PY
```

Warmup command:

```bash
CUDA_VISIBLE_DEVICES=0 \
./.venv/bin/python -m f5_tts.train.train -cn F5TTS_v1_Base \
datasets.name=mini_rl datasets.batch_size_per_gpu=2 datasets.batch_size_type=sample datasets.num_workers=2 \
model.tokenizer=custom model.tokenizer_path=$PWD/data/Emilia_ZH_EN_pinyin/vocab.txt \
model.output_dist=gaussian model.objective=gaussian_nll model.sample_from_dist=false model.use_rl_head=true \
model.arch.checkpoint_activations=false \
optim.epochs=1 optim.learning_rate=1e-5 optim.num_warmup_updates=0 optim.grad_accumulation_steps=1 optim.bnb_optimizer=true \
optim.mixed_precision=auto optim.tf32=true \
ckpts.save_dir=ckpts/mini_rl_warmup ckpts.save_per_updates=1000 ckpts.keep_last_n_checkpoints=0 ckpts.log_samples=false ckpts.logger=null
```

Checkpoint behavior:
- If no `model_*.pt` exists, the trainer loads `pretrained_*.safetensors` from the save dir.
- It writes `model_last.pt` and `model_<update>.pt`.
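
That resolution order can be sketched as follows (an illustrative helper, not the trainer's actual function; it only mirrors the "prefer `model_*.pt`, else fall back to `pretrained_*.safetensors`" rule described above):

```python
from pathlib import Path


def resolve_checkpoint(save_dir: str):
    """Sketch: pick the file a warmup run would load from save_dir."""
    d = Path(save_dir)
    # Prefer an existing training checkpoint...
    pts = sorted(d.glob("model_*.pt"))
    if pts:
        return pts[-1]
    # ...otherwise fall back to a pretrained safetensors file, if any.
    pretrained = sorted(d.glob("pretrained_*.safetensors"))
    return pretrained[-1] if pretrained else None
```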

## Stage 2: GRPO

Copy the warmup checkpoint into the GRPO directory:

```bash
mkdir -p ckpts/mini_rl_grpo
cp ckpts/mini_rl_warmup/model_last.pt ckpts/mini_rl_grpo/model_last.pt
```

Or, skip the copy and point GRPO directly at the warmup checkpoint:

```bash
ckpts.init_from=$PWD/ckpts/mini_rl_warmup/model_last.pt
```

GRPO command:

```bash
CUDA_VISIBLE_DEVICES=0 PYTORCH_ALLOC_CONF=expandable_segments:True \
./.venv/bin/python -m f5_tts.train.train_rl \
datasets.name=mini_rl datasets.batch_size_per_gpu=1 datasets.batch_size_type=sample datasets.num_workers=2 \
model.tokenizer=custom model.tokenizer_path=$PWD/data/Emilia_ZH_EN_pinyin/vocab.txt \
model.output_dist=gaussian model.objective=grpo model.use_rl_head=true model.arch.checkpoint_activations=false \
optim.epochs=1 optim.learning_rate=1e-6 optim.num_warmup_updates=0 optim.grad_accumulation_steps=1 optim.bnb_optimizer=true \
optim.mixed_precision=auto optim.tf32=true \
ckpts.save_dir=ckpts/mini_rl_grpo ckpts.save_per_updates=1000 ckpts.keep_last_n_checkpoints=0 ckpts.log_samples=false ckpts.logger=wandb \
rl.steps=30 rl.repeat_count=1 rl.mini_repeat_count=1 rl.prompt_frac_range='[0.1,0.3]' rl.prompt_length_mode=min \
rl.cfg_strength=2.0 rl.sway_sampling_coef=-1.0 rl.kl_weight=1.0 \
rl.ref_model_ckpt=$PWD/ckpts/mini_rl_warmup/model_last.pt \
rl.rewards.providers.0.config.model_dir=$PWD/checkpoints/wespeaker/cnceleb_resnet34/cnceleb_resnet34 \
rl.rewards.providers.0.config.device=cpu \
rl.rewards.providers.1.config.model_id=$PWD/checkpoints/funasr/SenseVoiceSmall \
rl.rewards.providers.1.config.device=cpu
```

Sample logging:
- Set `ckpts.log_samples=true` to save `update_*_gen.wav` / `update_*_ref.wav` under `ckpts/.../samples`
at each `ckpts.save_per_updates` interval.

## RL knobs (quick reference)

- `rl.steps`: number of diffusion/inference steps per GRPO rollout. Higher values improve
audio quality and reward signal (ASR/WER), but increase compute and memory.
- `rl.steps_plus_one`: opt-in to use `steps + 1` integration points in `forward_rl`. Default is `false`
for F5R parity; set `true` if you want RL rollouts to match non-RL step count.
- `rl.skip_grad_prob`: probability to skip gradient tracking per ODE step (default: `0.05`). Set `0.95`
for strict F5R parity.
- `rl.max_grad_steps`: optional cap on the number of ODE steps kept for GRPO (default: `null`).
- `rl.prompt_length_mode`: `min` (F5R parity), `per_sample`, or `range`. `range` uses the sampled
fraction directly so prompt length respects the lower bound in `prompt_frac_range`.
- `rl.kl_eps`: add a small epsilon to the KL denominator for extra numerical stability (default: 0.0).
- `rl.density_eps`: add a small epsilon to Gaussian density weighting for stability (default: 0.0).
- `rl.align_kl_steps`: share the ODE skip mask between policy/ref rollouts for a less noisy KL (default: `false`).
- `rl.max_duration`: Maximum allowed mel frames (default: 4096). Samples exceeding this are skipped to prevent truncation.
- `rl.legacy_length_check`: If `true`, enables the legacy filtering that skips samples with `text_len > mel_len` (F5R parity). Defaults to `false`, which keeps those samples.
- `rl.reward_ref_source`: `auto | audio_path | mel` (default: `auto`). `auto` uses dataset audio paths for the
speaker reference when available; `mel` forces vocoder decode for ref audio.
- `rl.reward_ref_cache_size`: number of reference audio entries to cache in memory (default: `128`).
- `wer_mode`: `char | word` (default: `char`, matching F5R).
- `ref_source`: `text | audio` (default: `text`; set `audio` to match ASR-vs-ASR reward in F5R).
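
The difference between the three `rl.prompt_length_mode` settings can be sketched like this (illustrative names only, not the trainer's API; the fraction range corresponds to `rl.prompt_frac_range`):

```python
import random


def prompt_frames(mel_lens, mode="min", frac_range=(0.1, 0.3), rng=random):
    """Sketch of how each prompt_length_mode picks prompt lengths (in frames)."""
    lo, hi = frac_range
    if mode == "min":
        # F5R parity: one fraction, applied to the batch-minimum length,
        # so every sample gets the same (possibly collapsed) prompt length.
        frac = rng.uniform(lo, hi)
        return [int(min(mel_lens) * frac)] * len(mel_lens)
    if mode == "per_sample":
        # A fresh fraction per sample, applied to that sample's own length.
        return [int(length * rng.uniform(lo, hi)) for length in mel_lens]
    if mode == "range":
        # One sampled fraction applied directly per sample, so the lower
        # bound of frac_range is respected for every sample.
        frac = rng.uniform(lo, hi)
        return [int(length * frac) for length in mel_lens]
    raise ValueError(f"unknown mode: {mode}")
```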

## Recommended opt-ins (deviations from F5R defaults)

Defaults match F5R except where noted (skip-grad, ref audio source). If you want the more robust behavior we found during integration, opt in:
- `rl.prompt_length_mode=range` (or `per_sample`) to avoid collapsing prompt lengths to the batch minimum and to honor `prompt_frac_range` lower bounds.
- `rl.steps_plus_one=true` to align RL rollouts with the non‑RL step count (`steps + 1` integration points).
- `rl.align_kl_steps=true` to keep KL timesteps aligned when the ODE skips gradients.
- `rl.kl_eps=1e-6` and `rl.density_eps=1e-6` if you see NaN/inf in KL or advantage weighting (keeps defaults at parity).
- `rl.skip_grad_prob=0.95` if you need strict F5R parity for skip‑grad frequency.
- `rl.rewards.providers.1.config.ref_source=audio` if you want ASR‑vs‑ASR reward instead of text‑vs‑ASR.
- `ckpts.log_samples=true` for debugging; writes sample WAVs under `ckpts/.../samples` at each save interval.
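
Taken together, these opt-ins can be appended to the GRPO command as extra Hydra overrides; treat this as a sketch and keep or drop lines per the notes above:

```bash
rl.prompt_length_mode=range \
rl.steps_plus_one=true \
rl.align_kl_steps=true \
rl.kl_eps=1e-6 rl.density_eps=1e-6 \
rl.rewards.providers.1.config.ref_source=audio \
ckpts.log_samples=true
```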

## Logging

W&B logs include:
- `loss`, `loss/kl`, `loss/pro_adv`
- `reward/mean`, `reward/std`, `reward/min`, `reward/max`
- `reward/speaker_similarity/cosine`
- `reward/asr/char_error_rate` or `reward/asr/word_error_rate` (depends on `wer_mode`)
- `stats/skip_ratio`, `stats/kept_steps`
- Sample audio when `ckpts.log_samples=true` (`samples/gen`, `samples/ref`)
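
The char/word metric switch amounts to computing edit distance over characters versus whitespace-split words. A self-contained sketch (not the trainer's implementation, which uses the reward provider's ASR output):

```python
def error_rate(ref: str, hyp: str, mode: str = "char") -> float:
    """Illustrative error rate: Levenshtein distance / reference length.

    mode="char" mirrors char_error_rate; mode="word" mirrors word_error_rate.
    """
    a = list(ref) if mode == "char" else ref.split()
    b = list(hyp) if mode == "char" else hyp.split()
    # Rolling-row edit-distance DP.
    d = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        prev, d[0] = d[0], i
        for j, y in enumerate(b, 1):
            cur = d[j]
            d[j] = min(d[j] + 1,          # deletion
                       d[j - 1] + 1,      # insertion
                       prev + (x != y))   # substitution (free on match)
            prev = cur
    return d[-1] / max(len(a), 1)
```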

Trackio (drop-in alternative):
```bash
./.venv/bin/python -m pip install -e ".[trackio]"
```
Then set `ckpts.logger=trackio` and view logs locally with:
```bash
trackio show
```

## Implementation parity notes

These details intentionally match the F5R reference code:
- Gaussian loss adds `t^2 * ln_sig` to regularize variance over time.
- GRPO uses Gaussian density weighting (not log-prob) for advantage shaping.
- ODE integration can skip gradients on some steps for speed; tune `rl.skip_grad_prob` if needed.
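
The first point can be sketched as follows, assuming `ln_sig` is the head's predicted log standard deviation (tensor names are illustrative, not the repo's actual code):

```python
import torch


def gaussian_nll_with_time_reg(pred_mu, pred_ln_sig, target, t):
    """Sketch of the warmup loss with the F5R-style t^2 * ln_sig regularizer.

    pred_mu, pred_ln_sig, target: (batch, seq, mel) tensors
    t: (batch,) timesteps in [0, 1]
    """
    # Per-element Gaussian NLL (up to a constant), with ln_sig = log(sigma).
    nll = 0.5 * (target - pred_mu) ** 2 * torch.exp(-2 * pred_ln_sig) + pred_ln_sig
    # Time-weighted variance regularizer: t^2 * ln_sig.
    reg = (t**2).view(-1, 1, 1) * pred_ln_sig
    return (nll + reg).mean()
```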

## Additional implementation notes

- DynamicBatchSampler now yields repeated batches via a generator to avoid large
in‑memory lists when `repeat_count`/`mini_repeat_count` are large; this is compatible
with Accelerate’s batch sharding and keeps deterministic epoch shuffling intact.
- `forward_rl(..., no_ref_audio=True)` keeps conditioning unless you pass
`strict_no_ref_audio=True` (opt‑in to fully zero conditioning before building `step_cond`).
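
The generator-style repetition in the first point can be sketched like this (illustrative, not the actual `DynamicBatchSampler` code):

```python
def repeated_batches(batches, repeat_count=1, mini_repeat_count=1):
    """Yield each batch repeat_count * mini_repeat_count times lazily,
    instead of materializing one large repeated list in memory."""
    for batch in batches:
        for _ in range(repeat_count):
            for _ in range(mini_repeat_count):
                yield batch
```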

## Troubleshooting

- `num_workers=0` is supported; `persistent_workers` is only enabled when `num_workers > 0`.
- If WeSpeaker fails to import, install the GitHub source and use an fbank model.
- If FunASR is missing, install `.[reward_funasr]` or `funasr==1.3.0`.
- If WER is flat, increase `rl.steps` (very low values often produce poor audio).
- On low disk, keep only the final checkpoint: `ckpts.keep_last_n_checkpoints=0`.
19 changes: 18 additions & 1 deletion pyproject.toml
@@ -48,12 +48,28 @@ dependencies = [
[project.optional-dependencies]
eval = [
"faster_whisper==0.10.1",
"funasr",
"funasr==1.3.0",
"jiwer",
"modelscope",
"zhconv",
"zhon",
]
rl = [
"funasr==1.3.0",
"huggingface_hub>=0.23.0",
# WeSpeaker pinned for compatibility/reproducibility; update hash if upstream changes or tags a release.
"wespeaker @ git+https://github.com/wenet-e2e/wespeaker.git@8f53b6485d9f88a207bd17e7f8dba899495ec794",
]
reward_funasr = [
"funasr==1.3.0",
]
reward_wespeaker = [
# Keep the WeSpeaker commit in sync with the `rl` extra.
"wespeaker @ git+https://github.com/wenet-e2e/wespeaker.git@8f53b6485d9f88a207bd17e7f8dba899495ec794",
]
trackio = [
"trackio",
]

[project.urls]
Homepage = "https://github.com/SWivid/F5-TTS"
@@ -63,3 +79,4 @@ Homepage = "https://github.com/SWivid/F5-TTS"
"f5-tts_infer-gradio" = "f5_tts.infer.infer_gradio:main"
"f5-tts_finetune-cli" = "f5_tts.train.finetune_cli:main"
"f5-tts_finetune-gradio" = "f5_tts.train.finetune_gradio:main"
"f5-tts_train-rl" = "f5_tts.train.train_rl:main"
1 change: 1 addition & 0 deletions ruff.toml
@@ -8,3 +8,4 @@ dummy-variable-rgx = "^_.*$"
[lint.isort]
force-single-line = false
lines-after-imports = 2
known-third-party = ["wandb"]
19 changes: 17 additions & 2 deletions src/f5_tts/api.py
@@ -34,7 +34,13 @@ def __init__(
):
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
model_arc = OmegaConf.to_container(model_cfg.model.arch, resolve=True)
output_dist = model_cfg.model.get("output_dist", "deterministic")
sample_from_dist = model_cfg.model.get("sample_from_dist", False)
if "output_dist" not in model_arc:
    model_arc["output_dist"] = output_dist
if "use_rl_head" not in model_arc and "use_rl_head" in model_cfg.model:
    model_arc["use_rl_head"] = model_cfg.model.use_rl_head

self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
@@ -80,7 +86,16 @@ def __init__(
cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
)
self.ema_model = load_model(
model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
model_cls,
model_arc,
ckpt_file,
self.mel_spec_type,
vocab_file,
self.ode_method,
self.use_ema,
self.device,
output_dist=output_dist,
sample_from_dist=sample_from_dist,
)

def transcribe(self, ref_audio, language=None):
7 changes: 6 additions & 1 deletion src/f5_tts/configs/E2TTS_Base.yaml
@@ -22,6 +22,10 @@ model:
  tokenizer: pinyin
  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
  backbone: UNetT
  output_dist: deterministic
  objective: mse
  sample_from_dist: False
  use_rl_head: False
  arch:
    dim: 1024
    depth: 24
@@ -46,4 +50,5 @@
  save_per_updates: 50000 # save checkpoint per updates
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 5000 # save last checkpoint per updates
  allow_extra_keys: False
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}