Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions audiodit/modeling_audiodit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn.utils import weight_norm
from torch.nn.utils.rnn import pad_sequence
Expand Down Expand Up @@ -53,15 +54,13 @@ def odeint_euler(fn, y0, t):
t: 1-D tensor of time steps (must be monotonically increasing)

Returns:
Tensor of shape `(len(t), *y0.shape)` containing the trajectory.
Final state tensor with the same shape as *y0*.
"""
ys = [y0]
y = y0
for i in range(len(t) - 1):
dt = t[i + 1] - t[i]
y = y + fn(t[i], y) * dt
ys.append(y)
return torch.stack(ys)
return y


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -298,7 +297,7 @@ def __init__(self, dim: int, heads: int, dim_head: int, dropout: float = 0.0, bi
if qk_norm:
self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps)
self.k_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps)
self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, dim, bias=bias), nn.Dropout(dropout)])
self.to_out = nn.Sequential(nn.Linear(self.inner_dim, dim, bias=bias), nn.Dropout(dropout))

def forward(self, x: torch.Tensor, mask: torch.BoolTensor | None = None, rope: tuple | None = None) -> torch.Tensor:
batch_size = x.shape[0]
Expand All @@ -320,9 +319,7 @@ def forward(self, x: torch.Tensor, mask: torch.BoolTensor | None = None, rope: t
attn_mask = mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.heads, query.shape[-2], key.shape[-2])
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, self.inner_dim).to(query.dtype)
x = self.to_out[0](x)
x = self.to_out[1](x)
return x
return self.to_out(x)


class AudioDiTCrossAttention(nn.Module):
Expand All @@ -337,7 +334,7 @@ def __init__(self, q_dim: int, kv_dim: int, heads: int, dim_head: int, dropout:
if qk_norm:
self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps)
self.k_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps)
self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)])
self.to_out = nn.Sequential(nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout))

def forward(
self, x: torch.Tensor, cond: torch.Tensor, mask: torch.BoolTensor | None = None,
Expand All @@ -364,9 +361,7 @@ def forward(
attn_mask = attn_mask.expand(batch_size, self.heads, query.shape[-2], key.shape[-2])
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, self.inner_dim).to(query.dtype)
x = self.to_out[0](x)
x = self.to_out[1](x)
return x
return self.to_out(x)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -444,7 +439,6 @@ def forward(
adaln_out = self.adaln_mlp(norm_cond)
gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1)
else:
from einops import rearrange
adaln_out = adaln_global_out + rearrange(self.adaln_scale_shift, "f -> 1 f")
gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1)

Expand Down Expand Up @@ -997,6 +991,8 @@ def forward(
attention_mask: torch.LongTensor | None = None,
text_embedding: torch.FloatTensor | None = None,
prompt_audio: torch.FloatTensor | None = None,
prompt_latent: torch.FloatTensor | None = None,
prompt_duration_frames: int | None = None,
duration: int | None = None,
steps: int = 16,
cfg_strength: float = 4.0,
Expand All @@ -1010,6 +1006,11 @@ def forward(
attention_mask: Attention mask ``(batch, seq_len)``.
text_embedding: Pre-computed text embeddings ``(batch, seq_len, dim)``. Alternative to input_ids.
prompt_audio: Optional prompt audio ``(batch, 1, num_samples)`` for voice cloning.
prompt_latent: Pre-encoded prompt latent ``(batch, num_frames, latent_dim)`` from
``encode_prompt_audio()``. Use instead of ``prompt_audio`` to avoid redundant
VAE encoding when the latent is already available.
prompt_duration_frames: Number of prompt latent frames. Required when
``prompt_latent`` is provided.
duration: Target duration in latent frames (prompt + gen). If None, uses max_wav_duration.
steps: Number of ODE Euler steps (default 16).
cfg_strength: Guidance strength for CFG/APG (default 4.0).
Expand Down Expand Up @@ -1040,7 +1041,10 @@ def forward(
batch = text_condition.shape[0]

# ── prompt audio encoding ─────────────────────────────────────
if prompt_audio is not None:
if prompt_latent is not None:
prompt_latent = prompt_latent.to(device)
prompt_dur = prompt_duration_frames
elif prompt_audio is not None:
prompt_latent, prompt_dur = self.encode_prompt_audio(prompt_audio)
else:
prompt_latent = torch.empty(batch, 0, self.config.latent_dim, device=device)
Expand All @@ -1061,7 +1065,7 @@ def forward(
neg_text_len = text_condition_len

latent_len = prompt_dur
if prompt_audio is not None:
if prompt_dur > 0:
gen_len = max_dur - latent_len
latent_cond = F.pad(prompt_latent, (0, 0, 0, gen_len))
empty_latent_cond = torch.zeros_like(latent_cond)
Expand All @@ -1075,6 +1079,7 @@ def forward(

# ── ODE function ──────────────────────────────────────────────
def fn(t, x):
x = x.clone()
x[:, :latent_len] = prompt_noise * (1-t) + latent_cond[:, :latent_len] * t
output = self.transformer(
x=x, text=text_condition, text_len=text_condition_len, time=t,
Expand Down Expand Up @@ -1120,12 +1125,11 @@ def fn(t, x):
# ── ODE solve ─────────────────────────────────────────────────
t = torch.linspace(0, 1, steps, device=device)
prompt_noise = y0[:, :latent_len].clone()
trajectory = odeint_euler(fn, y0, t)
sampled = trajectory[-1]
sampled = odeint_euler(fn, y0, t)

# ── decode ────────────────────────────────────────────────────
pred_latent = sampled
if prompt_audio is not None:
if prompt_dur > 0:
pred_latent = pred_latent[:, prompt_dur:]

pred_latent = pred_latent.permute(0, 2, 1).float()
Expand Down Expand Up @@ -1153,14 +1157,14 @@ def update(self, update_value: torch.Tensor):

def _project(v0: torch.Tensor, v1: torch.Tensor, dims=(-1, -2)):
dtype = v0.dtype
device_type = v0.device.type
if device_type == "mps":
orig_device = v0.device
if orig_device.type == "mps":
v0, v1 = v0.cpu(), v1.cpu()
v0, v1 = v0.double(), v1.double()
v1 = F.normalize(v1, dim=dims)
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(device_type)
return v0_parallel.to(dtype=dtype, device=orig_device), v0_orthogonal.to(dtype=dtype, device=orig_device)


def _apg_forward(pred_cond, pred_uncond, guidance_scale, momentum_buffer=None, eta=0.0, norm_threshold=2.5, dims=(-1, -2)):
Expand Down
16 changes: 4 additions & 12 deletions batch_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F

import audiodit
from audiodit import AudioDiTModel
Expand All @@ -41,16 +40,8 @@ def infer_one(gen_text, prompt_text, prompt_wav_path, model, tokenizer, device,
inputs = tokenizer([full_text], padding="longest", return_tensors="pt")
prompt_wav = load_audio(prompt_wav_path, sr).unsqueeze(0)

# Duration estimation
off = 3
pw = load_audio(prompt_wav_path, sr)
if pw.shape[-1] % full_hop != 0:
pw = F.pad(pw, (0, full_hop - pw.shape[-1] % full_hop))
pw = F.pad(pw, (0, full_hop * off))
plt = model.vae.encode(pw.unsqueeze(0).to(device))
if off:
plt = plt[..., :-off]
prompt_dur = plt.shape[-1]
# Encode prompt audio once (reused for duration estimation and generation)
prompt_latent, prompt_dur = model.encode_prompt_audio(prompt_wav.to(device))

prompt_time = prompt_dur * full_hop / sr
dur_sec = approx_duration_from_text(gen_text, max_duration - prompt_time)
Expand All @@ -63,7 +54,8 @@ def infer_one(gen_text, prompt_text, prompt_wav_path, model, tokenizer, device,
output = model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
prompt_audio=prompt_wav,
prompt_latent=prompt_latent,
prompt_duration_frames=prompt_dur,
duration=duration,
steps=nfe,
cfg_strength=cfg_strength,
Expand Down
17 changes: 5 additions & 12 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F

import audiodit # auto-registers AudioDiTConfig/AudioDiTModel
from audiodit import AudioDiTModel
Expand Down Expand Up @@ -74,19 +73,12 @@ def main():
if not no_prompt:
prompt_wav = load_audio(args.prompt_audio, sr).unsqueeze(0)

# Compute prompt duration for time estimation
off = 3
pw = load_audio(args.prompt_audio, sr)
if pw.shape[-1] % full_hop != 0:
pw = F.pad(pw, (0, full_hop - pw.shape[-1] % full_hop))
pw = F.pad(pw, (0, full_hop * off))
# Encode prompt audio once (reused for duration estimation and generation)
with torch.no_grad():
plt = model.vae.encode(pw.unsqueeze(0).to(device))
if off:
plt = plt[..., :-off]
prompt_dur = plt.shape[-1]
prompt_latent, prompt_dur = model.encode_prompt_audio(prompt_wav.to(device))
else:
prompt_wav = None
prompt_latent = None
prompt_dur = 0

# Duration estimation
Expand All @@ -104,7 +96,8 @@ def main():
output = model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
prompt_audio=prompt_wav,
prompt_latent=prompt_latent,
prompt_duration_frames=prompt_dur if prompt_latent is not None else None,
duration=duration,
steps=args.nfe,
cfg_strength=args.guidance_strength,
Expand Down