LAION-AI/jax-dacvae-echotts

JAX DACVAE EchoTTS

JAX/Flax training and inference code for EchoTTS, optimized for Google Cloud TPUs.

This is the JAX implementation of the Echo Diffusion Transformer (DiT) text-to-speech model, using DACVAE (Descript Audio Codec VAE) as the latent audio codec. The training code is designed to run on TPU v4 pods (tested on v4-64, 32 chips / 8 hosts) with FSDP sharding, streaming data pipeline, and multi-host checkpointing.

Architecture

Echo is a flow-matching Diffusion Transformer that generates speech from text:

  • Text Encoder: Byte-level (UTF-8, vocab=256) Transformer encoder trained from scratch
  • Speaker Encoder: Latent-space encoder that processes DACVAE speaker reference audio
  • Decoder: 24-layer Transformer with JointSelfCrossAttention (self + text cross + speaker cross)
  • Conditioning: AdaLN (Adaptive Layer Normalization) with rank-256 decomposition, SwiGLU MLP
  • Attention: QK-norm (RMSNorm), half-head RoPE, gated output (sigmoid(gate) * attn)
  • Output: Zero-initialized projection to latent space
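
The rank-256 AdaLN decomposition amounts to a low-rank bottleneck on the conditioning-to-modulation projection. A minimal NumPy illustration (names, shapes, and initialization here are illustrative, not the repo's actual API):

```python
import numpy as np

def adaln_modulation(cond, w_down, w_up):
    """Rank-256 AdaLN: route the conditioning through a low-rank bottleneck
    instead of a full dim -> 6*dim projection, then split into shift/scale/gate
    for the attention and MLP branches. Names and shapes are illustrative."""
    mods = (cond @ w_down) @ w_up          # (batch, rank) -> (batch, 6*dim)
    return np.split(mods, 6, axis=-1)

rng = np.random.default_rng(0)
dim, rank = 1536, 256                      # backbone width, AdaLN rank
cond = rng.standard_normal((2, dim))       # e.g. a timestep embedding
w_down = 0.02 * rng.standard_normal((dim, rank))
w_up = 0.02 * rng.standard_normal((rank, 6 * dim))
shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = adaln_modulation(cond, w_down, w_up)
print(shift_a.shape)  # (2, 1536)
```

The bottleneck cuts the per-layer modulation parameters from dim·6dim down to dim·r + r·6dim, roughly a 5x reduction at these sizes.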

DACVAE Codec

DACVAE (Facebook) converts audio to/from 128-dimensional latent representations at 25 frames per second, with 48kHz audio output. The model operates entirely in this latent space — text goes in, latent frames come out, DACVAE decodes to waveform.
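
The frame-rate arithmetic is worth making explicit: at 25 latent frames per second, a clip of T seconds maps to a (25·T, 128) latent array. A small helper (constants taken from the text above; the function name is ours):

```python
FRAME_RATE = 25       # DACVAE latent frames per second
LATENT_DIM = 128      # latent channels
SAMPLE_RATE = 48_000  # decoded audio sample rate

def latent_shape(duration_s: float) -> tuple[int, int]:
    """Latent array shape for a clip of the given duration."""
    return (round(duration_s * FRAME_RATE), LATENT_DIM)

print(latent_shape(15.0))  # (375, 128) -- the default 15 s inference window
print(latent_shape(30.0))  # (750, 128) -- e.g. the precomputed 30 s silence latent
```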

Model Configurations

| Config | Params | Backbone | Speaker Encoder | Text Encoder |
|--------|--------|----------|-----------------|--------------|
| 1.3B   | 1,304M | 1536-dim, 24L, 12H | 1024-dim, 10L, 8H | 1024-dim, 10L, 8H |
| 840M   | ~800M  | 1280-dim, 24L, 10H | 768-dim, 8L, 6H | 768-dim, 8L, 6H |

Training

Prerequisites

  • Google Cloud TPU v4 pod (v4-64 recommended: 32 chips, 8 hosts)
  • Python 3.10+
  • pip install -r requirements.txt
  • DACVAE weights: huggingface-cli download facebook/dacvae-watermarked weights.pth
  • Silence latent for padding: a precomputed latent of 30 s of silence run through the DACVAE encoder

Dataset Format

Training data is stored as tar files on HuggingFace Hub. Each tar contains:

  • {key}.npy — DACVAE latent (N_frames, 128) float16
  • {key}.ref.npy — Speaker reference latent (N_frames, 128) float16
  • {key}.json — Metadata with "text" field
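
A shard in this layout can be produced with the standard tarfile and numpy modules. A sketch (the helper name and shard filename are illustrative; only the member names and dtypes follow the format above):

```python
import io, json, os, tarfile, tempfile
import numpy as np

def add_sample(tar, key, latent, ref_latent, text):
    """Write one {key}.npy / {key}.ref.npy / {key}.json triple into an open tar."""
    for name, arr in [(f"{key}.npy", latent), (f"{key}.ref.npy", ref_latent)]:
        buf = io.BytesIO()
        np.save(buf, arr.astype(np.float16))   # latents are stored as float16
        info = tarfile.TarInfo(name)
        info.size = buf.tell()
        buf.seek(0)
        tar.addfile(info, buf)
    meta = json.dumps({"text": text}).encode()
    info = tarfile.TarInfo(f"{key}.json")
    info.size = len(meta)
    tar.addfile(info, io.BytesIO(meta))

shard_path = os.path.join(tempfile.mkdtemp(), "shard_000000.tar")
with tarfile.open(shard_path, "w") as tar:
    latent = np.zeros((375, 128), dtype=np.float16)   # 15 s at 25 fps
    add_sample(tar, "sample_0001", latent, latent, "Hello world.")
```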

Launch Training

```shell
# Deploy to all workers
gcloud compute tpus tpu-vm scp train_pretrain.py NODE:/path/train_pretrain.py \
    --zone=ZONE --worker=all
gcloud compute tpus tpu-vm scp model.py NODE:/path/model.py \
    --zone=ZONE --worker=all

# Launch on all workers simultaneously
gcloud compute tpus tpu-vm ssh NODE --zone=ZONE --worker=all \
    --command="bash /path/launch.sh"
```

All 8 workers must start within seconds of each other; jax.distributed.initialize() synchronizes them. Expect JIT compilation to take 2-10 minutes on the first step.

Training Configuration

Default configuration in train_pretrain.py:

  • Global batch size: 256 (8 per TPU chip)
  • Optimizer: AdamW (lr=1e-4 peak, cosine schedule, 5% warmup, b1=0.9, b2=0.99)
  • CFG dropout: 10% speaker, 10% text
  • Flow matching: x_t = (1-t)*x_0 + t*noise, target = noise - x_0
  • Timestep sampling: Stratified logit-normal
  • Checkpoints: Every 5000 steps, includes optimizer state for lossless resume
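
The flow-matching interpolant and a stratified logit-normal timestep sampler can be sketched as follows (the mu/sigma defaults and the exact stratification scheme are assumptions, not read from train_pretrain.py):

```python
from statistics import NormalDist
import numpy as np

def stratified_logit_normal(batch, rng, mu=0.0, sigma=1.0):
    """One uniform sample per stratum of [0, 1), pushed through the inverse
    normal CDF and a sigmoid: logit-normal timesteps with low variance
    across the batch. (mu/sigma defaults are illustrative.)"""
    u = (np.arange(batch) + rng.uniform(size=batch)) / batch
    u = np.clip(u, 1e-6, 1.0 - 1e-6)
    nd = NormalDist(mu, sigma)
    z = np.array([nd.inv_cdf(float(ui)) for ui in u])
    return 1.0 / (1.0 + np.exp(-z))

def flow_matching_pair(x0, noise, t):
    """Interpolant x_t = (1-t)*x_0 + t*noise and target v = noise - x_0."""
    t = t[:, None, None]  # broadcast over (frames, channels)
    return (1.0 - t) * x0 + t * noise, noise - x0
```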

Checkpoint Format

```python
{
    "params": pytree_of_numpy_arrays,    # model weights
    "opt_state": pytree_of_numpy_arrays, # Adam momentum + variance (for lossless resume)
    "step": int,
    "samples_seen": int,
}
```
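
Since the inference examples load .pkl files, a plain-pickle round trip over this format might look like the following sketch (function names are ours; the repo's actual save/resume logic lives in train_pretrain.py):

```python
import pickle

def save_checkpoint(path, params, opt_state, step, samples_seen):
    """Serialize the full training state, optimizer moments included,
    so a resume reproduces training exactly. Field names follow the
    checkpoint format above."""
    ckpt = {"params": params, "opt_state": opt_state,
            "step": int(step), "samples_seen": int(samples_seen)}
    with open(path, "wb") as f:
        pickle.dump(ckpt, f)

def load_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)
```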

Monitoring

monitor.py serves an HTTP dashboard on port 8080 with live loss/throughput charts. Expose via Cloudflare tunnel:

```shell
cloudflared tunnel --url http://localhost:8080
```

Disk Management

disk_watchdog.py runs on all workers to prevent disk-full crashes. It auto-cleans HuggingFace caches, old tars, and old checkpoints:

```shell
gcloud compute tpus tpu-vm ssh NODE --zone=ZONE --worker=all \
    --command="nohup python3 disk_watchdog.py > /tmp/watchdog.log 2>&1 &"
```
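
The core check such a watchdog performs can be illustrated with shutil.disk_usage; the threshold and cleanup message below are placeholders, not disk_watchdog.py's real policy:

```python
import shutil

def disk_free_fraction(path="."):
    """Fraction of free space on the filesystem containing `path`."""
    usage = shutil.disk_usage(path)
    return usage.free / usage.total

# Placeholder policy: the real thresholds and cleanup order (HF caches,
# then old tars, then old checkpoints) live in disk_watchdog.py.
if disk_free_fraction(".") < 0.10:
    print("low disk: time to clean caches, old tars, old checkpoints")
```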

Inference

```shell
# Without speaker reference (model generates a random voice)
python inference.py \
    --checkpoint checkpoints/step_0935000.pkl \
    --text "Hello, this is a test of the Echo text to speech system." \
    --output output.wav

# With speaker reference (voice cloning)
python inference.py \
    --checkpoint checkpoints/step_0935000.pkl \
    --text "Hello, this is a test of the Echo text to speech system." \
    --speaker_ref reference.mp3 \
    --output output.wav \
    --cfg_scale 3.0 \
    --steps 32
```
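
The sampler these flags drive is a plain Euler integration of the flow ODE dx/dt = v(x, t) from noise at t=1 to data at t=0, with classifier-free guidance mixing conditional and unconditional velocities. A self-contained sketch (velocity_fn stands in for the real model call; this is not the repo's inference.py):

```python
import numpy as np

def sample_euler_cfg(velocity_fn, shape, steps=32, cfg_scale=3.0, seed=42):
    """Euler integration of dx/dt = v(x, t) from noise (t=1) to data (t=0),
    with classifier-free guidance."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)                 # start from pure noise
    ts = np.linspace(1.0, 0.0, steps + 1)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        v_c = velocity_fn(x, t_cur, cond=True)     # conditional velocity
        v_u = velocity_fn(x, t_cur, cond=False)    # unconditional velocity
        v = v_u + cfg_scale * (v_c - v_u)          # CFG-mixed velocity
        x = x + (t_next - t_cur) * v               # t_next < t_cur: step toward data
    return x                                       # latent; DACVAE decodes to audio

latent = sample_euler_cfg(lambda x, t, cond: np.zeros_like(x), (375, 128))
print(latent.shape)  # (375, 128)
```

With zero velocity the sampler just returns its initial noise; in the real pipeline v comes from the DiT and the resulting latent frames are decoded to a waveform by DACVAE.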

Inference Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `--cfg_scale` | 3.0 | Classifier-free guidance scale |
| `--steps` | 32 | Number of Euler ODE steps |
| `--duration` | 15.0 | Max output duration (seconds) |
| `--seed` | 42 | Random seed for reproducibility |

Key Implementation Notes

  • TPU v4 batch size: BS=256 (8/device) is optimal for 1.3B params. BS=448 causes silent HBM thrashing (5x slower, no error).
  • Process index != worker index: On TPU pods, jax.process_index() is shuffled. Process 0 (which saves checkpoints) may be on any worker.
  • Checkpoint resume: Flax LogicallyPartitioned wrappers require careful pytree reconstruction when loading. See train_pretrain.py resume logic.
  • state.step must be a JAX array: Using a Python int causes recompilation on every step (~600s each).
  • TensorFlow mock: Required at import time to prevent HuggingFace from pulling in TF (conflicts with JAX on TPU).
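
The TensorFlow mock mentioned above amounts to registering a stub in sys.modules before any HuggingFace import; one common way to do it (an assumption about the repo's exact approach):

```python
import sys
from unittest.mock import MagicMock

# Register the stub *before* importing anything from HuggingFace, so that
# any `import tensorflow` resolves to the mock instead of real TF (which
# would otherwise be pulled in and conflict with JAX on the TPU).
sys.modules["tensorflow"] = MagicMock()

import tensorflow as tf  # resolves to the mock, not the real package
```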

License

This project is released for research purposes. See the original EchoTTS repository for license details.
