JAX/Flax training and inference code for EchoTTS, optimized for Google Cloud TPUs.
This is the JAX implementation of the Echo Diffusion Transformer (DiT) text-to-speech model, using DACVAE (Descript Audio Codec VAE) as the latent audio codec. The training code is designed to run on TPU v4 pods (tested on v4-64: 32 chips across 8 hosts) with FSDP sharding, a streaming data pipeline, and multi-host checkpointing.
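For orientation, a minimal sketch of the multi-host setup this implies; the actual mesh and partitioning rules live in `train_pretrain.py`, and the `fsdp` axis name here is an assumption:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One "fsdp" mesh axis spanning every chip in the pod.
jax.distributed.initialize()  # must run on every host, near-simultaneously
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

# Shard a parameter's leading axis across all chips (FSDP-style).
param_sharding = NamedSharding(mesh, P("fsdp"))
```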
Echo is a flow-matching Diffusion Transformer that generates speech from text:
- Text Encoder: Byte-level (UTF-8, vocab=256) Transformer encoder trained from scratch
- Speaker Encoder: Latent-space encoder that processes DACVAE speaker reference audio
- Decoder: 24-layer Transformer with JointSelfCrossAttention (self + text cross + speaker cross)
- Conditioning: AdaLN (Adaptive Layer Normalization) with rank-256 decomposition, SwiGLU MLP
- Attention: QK-norm (RMSNorm), half-head RoPE, gated output (`sigmoid(gate) * attn`)
- Output: Zero-initialized projection to latent space
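As a sketch of how the rank-256 AdaLN conditioning above can work in Flax (module and argument names are illustrative, not the repo's actual API):

```python
import flax.linen as nn
import jax.numpy as jnp

class AdaLNModulation(nn.Module):
    """Illustrative rank-decomposed AdaLN: conditioning -> scale/shift/gate."""
    dim: int           # block width, e.g. 1536 for the 1.3B config
    rank: int = 256    # low-rank bottleneck in the modulation path

    @nn.compact
    def __call__(self, x, cond):  # x: (B, T, dim), cond: (B, cond_dim)
        h = nn.Dense(self.rank)(nn.silu(cond))  # down-project conditioning
        # Zero-init the up-projection so every block starts as an identity map.
        mod = nn.Dense(3 * self.dim, kernel_init=nn.initializers.zeros)(h)
        shift, scale, gate = jnp.split(mod[:, None, :], 3, axis=-1)  # broadcast over T
        x_norm = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        return x_norm * (1 + scale) + shift, gate  # gate applied after attn/MLP
```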
DACVAE (Facebook) converts audio to/from 128-dimensional latent representations at 25 frames per second, with 48kHz audio output. The model operates entirely in this latent space — text goes in, latent frames come out, DACVAE decodes to waveform.
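A quick sanity check on the shapes this implies, for a 15-second clip:

```python
# 25 latent frames/s, 128 dims each, decoded to 48 kHz audio.
duration_s = 15.0
n_frames = int(duration_s * 25)             # 375 latent frames
latent_shape = (n_frames, 128)              # what the model emits, float16 on disk
n_audio_samples = int(duration_s * 48_000)  # 720,000 waveform samples after decode
```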
| Config | Params | Backbone | Encoder | Text Encoder |
|---|---|---|---|---|
| 1.3B | 1,304M | 1536-dim, 24L, 12H | 1024-dim, 10L, 8H | 1024-dim, 10L, 8H |
| 840M | ~800M | 1280-dim, 24L, 10H | 768-dim, 8L, 6H | 768-dim, 8L, 6H |
- Google Cloud TPU v4 pod (v4-64 recommended: 32 chips, 8 hosts)
- Python 3.10+
```bash
pip install -r requirements.txt
```

- DACVAE weights: `huggingface-cli download facebook/dacvae-watermarked weights.pth`
- Silence latent for padding: precomputed 30s of silence run through the DACVAE encoder
Training data is stored as tar files on HuggingFace Hub. Each tar contains:
- `{key}.npy` — DACVAE latent, shape (N_frames, 128), float16
- `{key}.ref.npy` — Speaker reference latent, shape (N_frames, 128), float16
- `{key}.json` — Metadata with a `"text"` field
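A hedged sketch of reading samples back out of a shard with only the standard library and NumPy (assumes keys contain no dots; the repo's streaming pipeline is more involved):

```python
import io
import json
import tarfile
import numpy as np

def iter_samples(tar_path):
    """Yield (key, sample) pairs, grouping the three members that share a key."""
    samples = {}
    with tarfile.open(tar_path) as tar:
        for member in tar:
            key, _, ext = member.name.partition(".")
            data = tar.extractfile(member).read()
            if ext == "json":
                samples.setdefault(key, {})["text"] = json.loads(data)["text"]
            elif ext == "ref.npy":
                samples.setdefault(key, {})["ref"] = np.load(io.BytesIO(data))
            elif ext == "npy":
                samples.setdefault(key, {})["latent"] = np.load(io.BytesIO(data))
            if len(samples.get(key, ())) == 3:   # all three parts seen
                yield key, samples.pop(key)
```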
```bash
# Deploy to all workers
gcloud compute tpus tpu-vm scp train_pretrain.py NODE:/path/train_pretrain.py \
  --zone=ZONE --worker=all
gcloud compute tpus tpu-vm scp model.py NODE:/path/model.py \
  --zone=ZONE --worker=all

# Launch on all workers simultaneously
gcloud compute tpus tpu-vm ssh NODE --zone=ZONE --worker=all \
  --command="bash /path/launch.sh"
```

All 8 workers must start within seconds of each other; `jax.distributed.initialize()` synchronizes them. JIT compilation takes 2-10 minutes on the first step.
Default configuration in `train_pretrain.py`:
- Global batch size: 256 (8 per TPU chip)
- Optimizer: AdamW (lr=1e-4 peak, cosine schedule, 5% warmup, b1=0.9, b2=0.99)
- CFG dropout: 10% speaker, 10% text
- Flow matching: `x_t = (1-t)*x_0 + t*noise`, target = `noise - x_0`
- Timestep sampling: stratified logit-normal
- Checkpoints: Every 5000 steps, includes optimizer state for lossless resume
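A compact JAX sketch of the objective and schedule above; `apply_fn`, its conditioning argument, and `total_steps` are illustrative assumptions, not the repo's API:

```python
import jax
import jax.numpy as jnp
import optax

def stratified_logit_normal(key, batch, mu=0.0, sigma=1.0):
    # One uniform draw per stratum [i/B, (i+1)/B), pushed through the inverse
    # normal CDF and a sigmoid -> logit-normal timesteps covering (0, 1) evenly.
    u = (jnp.arange(batch) + jax.random.uniform(key, (batch,))) / batch
    u = jnp.clip(u, 1e-6, 1.0 - 1e-6)
    return jax.nn.sigmoid(mu + sigma * jax.scipy.special.ndtri(u))

def flow_matching_loss(params, apply_fn, key, x0, cond):
    k_t, k_eps = jax.random.split(key)
    t = stratified_logit_normal(k_t, x0.shape[0])   # (B,)
    eps = jax.random.normal(k_eps, x0.shape)
    tb = t[:, None, None]                           # broadcast over (T, D)
    x_t = (1 - tb) * x0 + tb * eps                  # linear interpolant
    target = eps - x0                               # velocity the model regresses
    pred = apply_fn(params, x_t, t, cond)           # hypothetical signature
    return jnp.mean((pred - target) ** 2)

# Matching the stated optimizer: AdamW, peak lr 1e-4, cosine decay, 5% warmup.
total_steps = 1_000_000  # assumption for illustration
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-4,
    warmup_steps=int(0.05 * total_steps), decay_steps=total_steps)
tx = optax.adamw(schedule, b1=0.9, b2=0.99)
```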
Checkpoints are pickled dicts with the following layout:

```python
{
    "params": pytree_of_numpy_arrays,     # model weights
    "opt_state": pytree_of_numpy_arrays,  # Adam momentum + variance (for lossless resume)
    "step": int,
    "samples_seen": int,
}
```
`monitor.py` serves an HTTP dashboard on port 8080 with live loss/throughput charts. Expose it via a Cloudflare tunnel:
```bash
cloudflared tunnel --url http://localhost:8080
```

`disk_watchdog.py` runs on all workers to prevent disk-full crashes. It auto-cleans HuggingFace caches, old tars, and old checkpoints:
```bash
gcloud compute tpus tpu-vm ssh NODE --zone=ZONE --worker=all \
  --command="nohup python3 disk_watchdog.py > /tmp/watchdog.log 2>&1 &"
```

```bash
# Without speaker reference (model generates a random voice)
python inference.py \
--checkpoint checkpoints/step_0935000.pkl \
--text "Hello, this is a test of the Echo text to speech system." \
--output output.wav
# With speaker reference (voice cloning)
python inference.py \
--checkpoint checkpoints/step_0935000.pkl \
--text "Hello, this is a test of the Echo text to speech system." \
--speaker_ref reference.mp3 \
--output output.wav \
--cfg_scale 3.0 \
--steps 32
```

| Parameter | Default | Description |
|---|---|---|
| `--cfg_scale` | 3.0 | Classifier-free guidance scale |
| `--steps` | 32 | Number of Euler ODE steps |
| `--duration` | 15.0 | Max output duration (seconds) |
| `--seed` | 42 | Random seed for reproducibility |
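To make `--steps` and `--cfg_scale` concrete, a minimal sketch of Euler integration with classifier-free guidance; `apply_fn` and its conditioning arguments are illustrative, not the repo's API:

```python
import jax
import jax.numpy as jnp

def sample(apply_fn, params, text_cond, key, shape, steps=32, cfg_scale=3.0):
    x = jax.random.normal(key, shape)       # start from pure noise at t=1
    ts = jnp.linspace(1.0, 0.0, steps + 1)  # integrate t from 1 (noise) to 0 (data)
    for t0, t1 in zip(ts[:-1], ts[1:]):
        t = jnp.full((shape[0],), t0)
        v_cond = apply_fn(params, x, t, text_cond)
        v_uncond = apply_fn(params, x, t, None)  # conditioning dropped, as in training
        v = v_uncond + cfg_scale * (v_cond - v_uncond)
        x = x + (t1 - t0) * v                    # Euler step along the learned velocity
    return x                                     # latent frames for DACVAE to decode
```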
- TPU v4 batch size: BS=256 (8/device) is optimal for 1.3B params. BS=448 causes silent HBM thrashing (5x slower, no error).
- Process index != worker index: On TPU pods, `jax.process_index()` is shuffled; process 0 (which saves checkpoints) may be on any worker.
- Checkpoint resume: Flax `LogicallyPartitioned` wrappers require careful pytree reconstruction when loading; see the resume logic in `train_pretrain.py`.
- `state.step` must be a JAX array: using a Python int causes recompilation on every step (~600s each).
- TensorFlow mock: required at import time to prevent HuggingFace from pulling in TF (which conflicts with JAX on TPU).
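Two of these pitfalls are easy to show concretely (a hedged sketch; `save_checkpoint` is illustrative):

```python
import pickle
import jax
import jax.numpy as jnp

# 1) Keep the step counter as a JAX array; a Python int is baked into the
#    jitted train step as a fresh constant and forces a ~600s recompile each step.
step = jnp.asarray(0, dtype=jnp.int32)

# 2) Only process 0 writes checkpoints, and on a pod that process may live on
#    any physical worker; never assume it is worker 0.
def save_checkpoint(path, payload):
    if jax.process_index() == 0:
        with open(path, "wb") as f:
            pickle.dump(payload, f)
```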
- EchoTTS — Original Echo TTS model
- DACVAE (Descript Audio Codec) — Audio codec
- DACVAE Weights (Facebook) — Pretrained DACVAE
- Flow Matching — Lipman et al., 2022
- Diffusion Transformers (DiT) — Peebles & Xie, 2023
- Adaptive Layer Normalization — Used for timestep and conditioning injection
This project is released for research purposes. See the original EchoTTS repository for license details.