output_attentions=True forces eager attention, blocking SDPA/flash attention for English models (fix unreleased on master) #525

@cuttle-agent

output_attentions=True forces eager attention for all models, blocking SDPA/flash attention optimization (unreleased fix on master)

Bug Report

Summary

In the latest PyPI release (v0.1.7 / tag v0.1.2), t3.py hardcodes output_attentions=True in both the initial forward pass and the generation loop. This forces PyTorch's eager attention implementation, disabling SDPA and flash_attention optimizations.

For English models, the AlignmentStreamAnalyzer that consumes these attention weights is None — so the attention outputs are computed but never used. This wastes GPU memory bandwidth and compute on every forward pass during autoregressive generation.

The fix is already on master (both calls changed to output_attentions=False) but has never been released, leaving all pip users affected.

Impact

  • Every Chatterbox user installing via pip install chatterbox-tts (which installs v0.1.7)
  • English models compute full attention weights that are never consumed — pure waste
  • Blocks SDPA/flash attention for T3 transformer, which is the most expensive component
  • On ROCm: forces eager attention mode, preventing AOTriton flash attention from working
  • On CUDA: prevents torch.nn.functional.scaled_dot_product_attention from using Flash Attention 2 or memory-efficient attention kernels
  • Compounds with issue #339 ("SDPA Compatibility Error: output_attentions not supported when using voice references with transformers >=4.36"), an SDPA crash triggered by output_attentions=True; the root cause is the same hardcoding

Reproduction

# Install latest release
# pip install chatterbox-tts==0.1.7

from chatterbox.tts import ChatterboxTTS

# In t3.py, lines ~311 and ~362 (v0.1.2 tag):
#   output_attentions=True,   # <-- hardcoded, forces eager attention
#   ...
#   output_attentions=True,   # <-- same in generation loop

# With transformers >=4.36 (SDPA default), this causes:
#   ValueError: The `output_attentions` attribute is not supported
#   when using `attn_implementation` set to sdpa.
# (Issue #339)

# With ROCm/AOTriton flash attention built from source:
#   Forces fallback to eager mode, losing flash attention speedup

# For English models, alignment_stream_analyzer is None, so these
# attention outputs are never consumed:
model = ChatterboxTTS.from_pretrained(device="cuda")
# model.alignment_stream_analyzer is None for English
# → attention weights computed and discarded every step
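To check whether an installed copy is affected without running inference, one option is to scan the installed t3.py for the hardcoded flag. The following is a minimal sketch, not part of Chatterbox: the helper names `count_hardcoded_flags` and `installed_t3_source` are hypothetical, and the module path is derived from the file path named in this report (src/chatterbox/models/t3/t3.py).

```python
import importlib.util
import re
from pathlib import Path
from typing import Optional

# Matches `output_attentions=True`, allowing whitespace around `=`.
_FLAG = re.compile(r"\boutput_attentions\s*=\s*True\b")


def count_hardcoded_flags(source: str) -> int:
    """Count `output_attentions=True` occurrences outside of comments."""
    count = 0
    for line in source.splitlines():
        # Naive comment stripping: does not handle '#' inside string literals.
        code = line.split("#", 1)[0]
        count += len(_FLAG.findall(code))
    return count


def installed_t3_source(module: str = "chatterbox.models.t3.t3") -> Optional[str]:
    """Return the source of the installed module, or None if it cannot be located.

    Note: find_spec imports the parent packages in order to resolve a dotted name.
    """
    try:
        spec = importlib.util.find_spec(module)
    except ImportError:
        return None
    if spec is None or spec.origin is None:
        return None
    return Path(spec.origin).read_text(encoding="utf-8")


# Usage (assumes chatterbox-tts is installed in the current environment):
#   src = installed_t3_source()
#   if src is not None:
#       print(count_hardcoded_flags(src))
```

Going by the description in this report, an affected release would be expected to report two occurrences and a build from master none, but that has not been verified here.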

Root Cause

In src/chatterbox/models/t3/t3.py (v0.1.2 tag), lines ~311 and ~362:

# Initial forward pass
output = self.patched_model(
    inputs_embeds=inputs_embeds,
    past_key_values=None,
    use_cache=True,
    output_attentions=True,  # ← always True
    output_hidden_states=True,
    return_dict=True,
)

# Generation loop
output = self.patched_model(
    inputs_embeds=next_token_embed,
    past_key_values=past,
    output_attentions=True,  # ← always True
    output_hidden_states=True,
    return_dict=True,
)

The AlignmentStreamAnalyzer that needs these weights is only instantiated for multilingual models. For English models, self.patched_model.alignment_stream_analyzer is None — but output_attentions=True is unconditional.

Fix (already on master)

The master branch already has this fixed:

output_attentions=False,

at both locations. However, this change has not been released — pip still installs the broken version.

Suggested Fix

If a new release isn't imminent, the conditional fix is more precise:

output_attentions=self.patched_model.alignment_stream_analyzer is not None,

This preserves multilingual alignment functionality while enabling SDPA/flash attention for English models.
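As an illustration of the conditional, the sketch below builds the keyword arguments for the model call and requests attention weights only when an analyzer exists to consume them. The helper `forward_kwargs` is hypothetical and not part of Chatterbox; the argument names mirror the initial forward pass call shown under Root Cause.

```python
from typing import Any, Dict, Optional


def forward_kwargs(
    inputs_embeds: Any,
    past_key_values: Optional[Any],
    alignment_stream_analyzer: Optional[Any],
) -> Dict[str, Any]:
    """Build model-call kwargs; ask for attention weights only if they will be used."""
    return {
        "inputs_embeds": inputs_embeds,
        "past_key_values": past_key_values,
        "use_cache": True,
        # True only when an analyzer is attached (the multilingual case in this report).
        "output_attentions": alignment_stream_analyzer is not None,
        "output_hidden_states": True,
        "return_dict": True,
    }


# English model: no analyzer, so attention weights are not requested.
english = forward_kwargs("embeds", None, None)
# Multilingual model: any non-None analyzer object turns the flag on.
multilingual = forward_kwargs("embeds", None, object())
```

Deriving the flag in one place keeps the initial pass and the generation loop from drifting apart again.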

Why This Matters for Performance

The T3 transformer is memory-bound during autoregressive generation, and each step is a full transformer forward pass. Requesting attention weights that are never used has three effects:

  1. Each layer materializes a (batch, heads, q_len, kv_len) attention-weight tensor on every step (q_len is the prompt length on the initial pass and 1 on cached decode steps)
  2. The model is forced onto the eager attention kernel instead of the memory-efficient SDPA/flash kernels
  3. The weights are stored in the BaseModelOutputWithPast object but never accessed
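To put a rough number on the first point, the sketch below estimates the bytes of attention weights materialized across all layers for one forward pass. It assumes the usual Hugging Face shape of (batch, heads, q_len, kv_len) per layer, where q_len is 1 on a decode step that reuses the KV cache. The head count, layer count, sequence length, and fp32 element size are illustrative assumptions, not values measured from Chatterbox.

```python
def attention_weight_bytes(
    batch: int,
    heads: int,
    q_len: int,
    kv_len: int,
    layers: int,
    bytes_per_element: int = 4,  # fp32; use 2 for fp16/bf16
) -> int:
    """Bytes needed to hold every layer's attention-weight tensor for one forward pass."""
    return batch * heads * q_len * kv_len * layers * bytes_per_element


# Illustrative configuration (assumed, not taken from the T3 config).
prefill = attention_weight_bytes(batch=1, heads=16, q_len=512, kv_len=512, layers=30)
decode = attention_weight_bytes(batch=1, heads=16, q_len=1, kv_len=512, layers=30)

print(f"prefill: {prefill / 2**20:.1f} MiB")   # prefill: 480.0 MiB
print(f"decode:  {decode / 2**20:.4f} MiB")    # decode:  0.9375 MiB
```

Under these assumptions the initial pass dominates, while each cached decode step adds a smaller allocation that grows linearly with the number of tokens generated so far.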

In my testing (Radeon 8060S, gfx1151, ROCm 6.4.3):

  • With output_attentions=True: T3 runs at ~24 it/s (memory-bound ceiling)
  • With output_attentions=False: same it/s ceiling (still memory-bound), but enables SDPA/flash attention path for future optimization and reduces per-step memory allocation

On CUDA GPUs with flash attention available, the impact should be larger, since eager attention is significantly slower than flash attention for longer sequences.

Environment

  • chatterbox-tts v0.1.7 (pip, tag v0.1.2)
  • PyTorch 2.8.0a0 (built from source for ROCm gfx1151)
  • ROCm 6.4.3
  • Ubuntu 26.04

cc @resemble-ai — could a patch release be cut with this fix? It's already on master.
