
Add Tesla T4 compatibility for the 0.5B dense model (fp32 fallback + SDPA fallback) #12

Description

@akshatvishu

This issue is specifically about the 0.5B dense model. While the MoE variants may still require bf16-capable hardware, the 0.5B dense model is small enough that supporting Tesla T4-class GPUs seems both practical and useful. T4 is widely available through platforms such as Colab and Kaggle, so enabling stable inference there would improve accessibility for many researchers and developers.

Proposed approach

Runtime load dtype selection

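```python
import torch

# Load in bf16 only when the GPU supports it natively. Emulated bf16
# (e.g. on a T4) would otherwise count as supported, so use fp16 there.
self.llm_dtype = (
    torch.bfloat16
    if torch.cuda.is_bf16_supported(including_emulation=False)
    else torch.float16
)
```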

fp32 promotion during generation on GPUs without native bf16

On T4, running the audio pipeline in fp16 produces NaNs.
Upcasting the full model to fp32 before generation avoids this issue:

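```python
# fp16 inference produces NaNs on T4, so cast every floating-point
# parameter and buffer of self to fp32 before generating.
if next(self.model.parameters()).dtype == torch.float16:
    self.float()
```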
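The issue does not identify where the fp16 NaNs originate. One common cause, assumed here purely for illustration, is fp16's narrow range: the largest finite half-precision value is 65504, so a larger intermediate overflows to `inf`, and arithmetic such as `inf - inf` then yields `nan`. The torch-free sketch below mimics that cast with the standard-library `struct` half-precision format (`'e'`); `to_fp16` is a hypothetical helper. fp32, whose largest finite value is about 3.4e38, represents the same number without trouble, which is why promotion sidesteps this failure mode.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def to_fp16(x: float) -> float:
    """Round x to half precision, saturating to +/-inf on overflow.

    struct raises OverflowError for out-of-range values, whereas a tensor
    cast to float16 yields inf; the except branch reproduces the latter.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


activation = 70000.0          # representable in fp32
half = to_fp16(activation)    # overflows to inf in fp16
print(half, half - half)      # inf nan
```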

flash-attn fallback

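Add a fallback to PyTorch's native SDPA (`torch.nn.functional.scaled_dot_product_attention`) when flash-attn is not installed or does not support the hardware. FlashAttention 2 requires an Ampere-or-newer GPU, so it cannot run on the T4 (Turing).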
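A minimal sketch of such a fallback, assuming PyTorch 2.x (where `torch.nn.functional.scaled_dot_product_attention` exists) and the `flash_attn_func` entry point of the flash-attn package. FlashAttention 2 also accepts only fp16/bf16 inputs, so on a T4 the fp32-promoted model needs the SDPA path for two reasons. `flash_attn_installed`, `choose_attention_backend`, and `attention` are hypothetical names, and the repository's real attention call may use a different tensor layout.

```python
from __future__ import annotations

import importlib.util


def flash_attn_installed() -> bool:
    """True if the flash_attn package is importable in this environment."""
    return importlib.util.find_spec("flash_attn") is not None


def choose_attention_backend(
    flash_installed: bool, capability: tuple[int, int], dtype_name: str
) -> str:
    """Pick "flash_attn" only when every FlashAttention 2 requirement holds."""
    if (
        flash_installed
        and tuple(capability) >= (8, 0)  # Ampere or newer
        and dtype_name in ("float16", "bfloat16")  # no fp32 kernels
    ):
        return "flash_attn"
    return "sdpa"


def attention(q, k, v, causal: bool = True):
    """Attention over (batch, heads, seq, head_dim) tensors."""
    import torch
    import torch.nn.functional as F

    backend = "sdpa"
    if q.is_cuda:
        backend = choose_attention_backend(
            flash_attn_installed(),
            torch.cuda.get_device_capability(q.device),
            str(q.dtype).removeprefix("torch."),
        )
    if backend == "flash_attn":
        from flash_attn import flash_attn_func

        # flash-attn expects (batch, seq, heads, head_dim), so swap dims 1 and 2.
        out = flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal
        )
        return out.transpose(1, 2)
    # PyTorch's native attention runs on the T4 and accepts fp32 inputs.
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
```

Keeping the backend decision in a pure function makes the T4 case easy to check: `choose_attention_backend(True, (7, 5), "float16")` returns `"sdpa"` even when flash-attn is installed.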

Validation:

I have tested this approach on a Kaggle T4 (with all the examples from cookbook.ipynb and test.py), and it appears to work. The notebook is available here.

Draft PR: here
