Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions candle-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
enterpolation = { version = "0.2.1", optional = true }
pyo3 = { version = "0.22.0", features = [
"auto-initialize",
"abi3-py311",
], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }
Expand All @@ -36,7 +39,8 @@ serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2" , optional = true}
pdf2image = { version = "0.1.2", optional = true }
tekken-rs = { version = "0.1.1", optional = true }

[dev-dependencies]
anyhow = { workspace = true }
Expand All @@ -58,11 +62,26 @@ bindgen_cuda = { version = "0.1.1", optional = true }

[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
accelerate = [
"dep:accelerate-src",
"candle/accelerate",
"candle-nn/accelerate",
"candle-transformers/accelerate",
]
cuda = [
"candle/cuda",
"candle-nn/cuda",
"candle-transformers/cuda",
"dep:bindgen_cuda",
]
cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
mkl = [
"dep:intel-mkl-src",
"candle/mkl",
"candle-nn/mkl",
"candle-transformers/mkl",
]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
Expand All @@ -71,6 +90,7 @@ encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
tekken = ["tekken-rs"]

[[example]]
name = "llama_multiprocess"
Expand Down Expand Up @@ -131,3 +151,7 @@ required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]

[[example]]
name = "voxtral"
required-features = ["symphonia"]
29 changes: 0 additions & 29 deletions candle-examples/examples/snac/audio_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,32 +244,3 @@ pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>
}
Ok((pcm_data, sample_rate))
}

pub(crate) fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
use rubato::Resampler;

let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);

let mut resampler =
rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}

if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}

Ok(pcm_out)
}
2 changes: 1 addition & 1 deletion candle-examples/examples/snac/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ fn main() -> Result<()> {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != model_sample_rate {
println!("WARNING: snac uses a {model_sample_rate} sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate, model_sample_rate)?
candle_examples::audio::resample(&pcm, sample_rate, model_sample_rate)?
} else {
pcm
}
Expand Down
25 changes: 25 additions & 0 deletions candle-examples/examples/voxtral/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# candle-voxtral: speech recognition

An implementation of Voxtral speech recognition using candle.

## Running the example

Run with the `cuda` feature for GPU acceleration:
```bash
cargo run --example voxtral --features tekken,symphonia,rubato,cuda --release
# you may also add the `cudnn` feature for extra performance
# cargo run --example voxtral --features tekken,symphonia,rubato,cuda,cudnn --release
```

Remove the `cuda` feature to run on the CPU instead:
```bash
cargo run --example voxtral --features tekken,symphonia,rubato --release
# or pass the `--cpu` flag to force CPU usage
# cargo run --example voxtral --features tekken,symphonia,rubato,cuda --release -- --cpu
```

## Command line options

- `--cpu`: Run on CPU rather than on GPU (default: false, uses GPU if available)
- `--input`: Audio file path in wav format. If not provided, a sample file is automatically downloaded from the hub.
- `--model-id`: Model to use (default: `mistralai/Voxtral-Mini-3B-2507`)
75 changes: 75 additions & 0 deletions candle-examples/examples/voxtral/download.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use std::path::PathBuf;

use anyhow::Result;
use hf_hub::{api::sync::Api, Repo, RepoType};

/// # Errors
///
/// Returns an error if the model files cannot be downloaded.
///
/// # Panics
///
/// Panics if the model files cannot be downloaded.
pub fn model_files(model_id: &str) -> Result<((PathBuf, Vec<PathBuf>), PathBuf)> {
let revision = "main";

let api = Api::new().unwrap();
let repo = api.repo(Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
));

let config = repo.get("config.json")?;

// Download model files - look for safetensors
let mut model_files = Vec::new();

// Common Voxtral/Ultravox safetensors file patterns
let safetensors_files = match model_id {
"mistralai/Voxtral-Mini-3B-2507" => vec![
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
],
"mistralai/Voxtral-Small-24B-2507" => vec![
"model-00001-of-00011.safetensors",
"model-00001-of-00011.safetensors",
"model-00002-of-00011.safetensors",
"model-00003-of-00011.safetensors",
"model-00004-of-00011.safetensors",
"model-00005-of-00011.safetensors",
"model-00006-of-00011.safetensors",
"model-00007-of-00011.safetensors",
"model-00008-of-00011.safetensors",
"model-00009-of-00011.safetensors",
"model-00010-of-00011.safetensors",
"model-00011-of-00011.safetensors",
],
_ => vec![
"model.safetensors",
"pytorch_model.safetensors",
"model-00001-of-00001.safetensors",
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
],
};

println!("Downloading safetensors files...");
for filename in &safetensors_files {
if let Ok(file) = repo.get(filename) {
println!("{} downloaded", filename);
model_files.push(file);
}
}

if model_files.is_empty() {
anyhow::bail!("No safetensors files found in model repository {model_id}",);
}

// Download tokenizer
let tokenizer_file = repo
.get("tekken.json")
.or_else(|_| repo.get("tokenizer/tokenizer.json"))?;

Ok(((config, model_files), tokenizer_file))
}
75 changes: 75 additions & 0 deletions candle-examples/examples/voxtral/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use anyhow::{Context, Result};
use clap::Parser;
use hf_hub::api::sync::Api;
use model::VoxtralModel;

mod download;
mod model;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long, default_value_t = false)]
cpu: bool,

/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
/// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
/// repo: https://huggingface.co/datasets/Narsil/candle_demo/
#[arg(long)]
input: Option<String>,

#[arg(long, default_value = "mistralai/Voxtral-Mini-3B-2507")]
model_id: Option<String>,
}

#[cfg(feature = "cuda")]
fn use_cpu() -> bool {
true
}

#[cfg(not(feature = "cuda"))]
fn use_cpu() -> bool {
false
}

fn main() -> Result<()> {
let args = Args::parse();

let use_cpu = args.cpu || !use_cpu();

let model_id = args.model_id.unwrap();

// Create model - equivalent to loading the model and processor in Python
let mut model =
VoxtralModel::new(&model_id, use_cpu).context("Failed to load Voxtral model")?;

println!("Model loaded successfully on device: {:?}", model.device());

let api = Api::new()?;
let dataset = api.dataset("Narsil/candle-examples".to_string());

let audio_file = if let Some(input) = args.input {
if let Some(sample) = input.strip_prefix("sample:") {
dataset.get(&format!("samples_{sample}.wav"))?
} else {
std::path::PathBuf::from(input)
}
} else {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
dataset.get("samples_jfk.wav")?
};

let (audio_data, sample_rate) =
candle_examples::audio::pcm_decode(audio_file).context("Failed to decode audio file")?;

// Transcribe audio with token output
let result = model
.transcribe_audio(&audio_data, sample_rate)
.context("Failed to transcribe audio with tokens")?;

println!("\n===================================================\n");
println!("{}", result.text);

Ok(())
}
Binary file not shown.
Loading
Loading