Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 191 additions & 17 deletions candle-examples/examples/whisper/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_nn::{
ops::{log_softmax, softmax},
VarBuilder,
};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::distr::weighted::WeightedIndex;
Expand Down Expand Up @@ -88,6 +91,7 @@ struct Decoder {
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
max_initial_timestamp_index: Option<u32>,
verbose: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
Expand All @@ -110,6 +114,7 @@ impl Decoder {
language_token: Option<u32>,
task: Option<Task>,
timestamps: bool,
max_initial_timestamp_index: Option<u32>,
verbose: bool,
) -> Result<Self> {
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
Expand Down Expand Up @@ -144,6 +149,7 @@ impl Decoder {
tokenizer,
task,
timestamps,
max_initial_timestamp_index,
verbose,
suppress_tokens,
sot_token,
Expand All @@ -157,12 +163,11 @@ impl Decoder {
}

fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder_forward(mel, true)?;
let audio_features = self.model.encoder_forward(mel, true)?;
if self.verbose {
println!("audio features: {:?}", audio_features.dims());
}
let sample_len = model.config().max_target_positions / 2;
let sample_len = self.model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
Expand All @@ -182,29 +187,33 @@ impl Decoder {
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
let ys = self
.model
.decoder_forward(&tokens_t, &audio_features, i == 0)?;

// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = self.model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}

let (_, seq_len, _) = ys.dims3()?;
let logits = model
let logits = self
.model
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
// ApplyTimestampRules, i.e.:
// - Timestamps come in pairs, except before EOT.
// - Timestamps should be non-decreasing.
// - If the sum of the probabilities of timestamps is higher than any other tokens,
// only consider timestamps when sampling.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439

// Apply timestamp rules when timestamps are enabled
let logits = if self.timestamps {
self.apply_timestamp_rules(&logits, &tokens)?
} else {
logits
};

let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
Expand All @@ -224,7 +233,9 @@ impl Decoder {
let prob = softmax(&logits, candle::D::Minus1)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
if next_token == self.eot_token
|| tokens.len() > self.model.config().max_target_positions
{
break;
}
sum_logprob += prob.ln();
Expand Down Expand Up @@ -265,6 +276,164 @@ impl Decoder {
unreachable!()
}

/// Applies the Whisper timestamp decoding heuristics to the next-token logits.
///
/// This follows `ApplyTimestampRules` from the reference OpenAI Whisper decoder:
///
/// 1. Timestamps come in pairs, except directly before EOT.
/// 2. Timestamps never decrease, and a segment must have a non-zero length.
/// 3. The very first sampled token must be a timestamp, optionally capped by
///    `max_initial_timestamp_index`.
/// 4. If the total probability mass on timestamp tokens exceeds the probability of the
///    most likely text token, only timestamp tokens stay eligible.
///
/// # Arguments
/// * `input_logits` - rank-1 `f32` logits over the vocabulary for the next token.
/// * `tokens` - every token of the current sequence, prompt prefix included.
///
/// # Returns
/// A new logits tensor in which forbidden tokens have been pushed to `-inf`.
///
/// # Errors
/// Propagates any tensor creation or arithmetic error reported by candle.
fn apply_timestamp_rules(&self, input_logits: &Tensor, tokens: &[u32]) -> Result<Tensor> {
    /// Marks `mask[start..end]` as forbidden, clamping the range to the mask length.
    fn suppress(mask: &mut [f32], start: usize, end: usize) {
        let end = end.min(mask.len());
        if start < end {
            mask[start..end].fill(f32::NEG_INFINITY);
        }
    }

    let device = input_logits.device();
    let vocab_size = self.model.config().vocab_size;
    // Timestamp tokens are the ids located right after the no-timestamps token.
    let timestamp_begin = self.no_timestamps_token as usize + 1;
    let eot = self.eot_token as usize;
    let is_timestamp = |token: u32| token as usize >= timestamp_begin;

    // Length of the prompt prefix that precedes the sampled tokens
    // (presumably SOT, the optional language token and the task token - confirm in `decode`).
    let sample_begin = if self.language_token.is_some() { 3 } else { 2 };
    let sampled_tokens = tokens.get(sample_begin..).unwrap_or(&[]);

    // Every rule only ever forbids tokens, so all of them are accumulated in a single
    // additive mask (0 = keep, -inf = forbid) that is uploaded and added exactly once.
    let mut mask = vec![0f32; vocab_size];

    // ========== RULE 1: Timestamp pairing constraints ==========
    let last_was_timestamp = sampled_tokens.last().map_or(false, |&t| is_timestamp(t));
    // With fewer than two sampled tokens the reference implementation treats the
    // penultimate token as a timestamp, so that the forced initial timestamp is
    // followed by text rather than by a second timestamp.
    let penultimate_was_timestamp = match sampled_tokens.len().checked_sub(2) {
        Some(index) => is_timestamp(sampled_tokens[index]),
        None => true,
    };

    if last_was_timestamp {
        if penultimate_was_timestamp {
            // A pair was just completed: the next token has to be a non-timestamp.
            suppress(&mut mask, timestamp_begin, vocab_size);
        } else {
            // A pair is open: normal text tokens (everything below EOT) are forbidden.
            suppress(&mut mask, 0, eot);
        }
    }

    // ========== RULE 2: Non-decreasing timestamp constraint ==========
    let last_timestamp = sampled_tokens
        .iter()
        .rev()
        .copied()
        .find(|&t| is_timestamp(t));
    if let Some(last_timestamp) = last_timestamp {
        let last_timestamp = last_timestamp as usize;
        // The closing timestamp of a pair may repeat the opening one; otherwise the next
        // timestamp must be strictly greater so that every segment has a non-zero length.
        let first_allowed = if last_was_timestamp && !penultimate_was_timestamp {
            last_timestamp
        } else {
            last_timestamp + 1
        };
        suppress(&mut mask, timestamp_begin, first_allowed);
    }

    // ========== RULE 3: Force initial timestamp ==========
    if tokens.len() == sample_begin {
        // Nothing has been sampled yet: only timestamp tokens may start the output.
        suppress(&mut mask, 0, timestamp_begin);

        if let Some(max_initial_timestamp_index) = self.max_initial_timestamp_index {
            // Saturating arithmetic keeps oversized CLI values from overflowing; the
            // range is clamped to the vocabulary inside `suppress`.
            let last_allowed =
                timestamp_begin.saturating_add(max_initial_timestamp_index as usize);
            suppress(&mut mask, last_allowed.saturating_add(1), vocab_size);
        }
    }

    let mut logits = input_logits.broadcast_add(&Tensor::new(mask.as_slice(), device)?)?;

    // ========== RULE 4: Probability-based timestamp preference ==========
    let log_probs = log_softmax(&logits, 0)?;
    let timestamp_log_probs =
        log_probs.narrow(0, timestamp_begin, vocab_size - timestamp_begin)?;
    let text_log_probs = log_probs.narrow(0, 0, timestamp_begin)?;

    // Numerically stable log-sum-exp over the timestamp tokens.
    let max_timestamp_log_prob = timestamp_log_probs.max(0)?;
    let timestamp_logprob = if max_timestamp_log_prob.to_scalar::<f32>()? == f32::NEG_INFINITY
    {
        // Every timestamp is masked out: shifting by -inf would only produce NaNs.
        f32::NEG_INFINITY
    } else {
        let log_sum = timestamp_log_probs
            .broadcast_sub(&max_timestamp_log_prob)?
            .exp()?
            .sum(0)?
            .log()?;
        max_timestamp_log_prob
            .broadcast_add(&log_sum)?
            .to_scalar::<f32>()?
    };
    let max_text_token_logprob = text_log_probs.max(0)?.to_scalar::<f32>()?;

    // Both quantities are log-probabilities, so they can be compared directly.
    if timestamp_logprob > max_text_token_logprob {
        let mut text_mask = vec![0f32; vocab_size];
        suppress(&mut text_mask, 0, timestamp_begin);
        logits = logits.broadcast_add(&Tensor::new(text_mask.as_slice(), device)?)?;
    }

    Ok(logits)
}

fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
Expand Down Expand Up @@ -465,10 +634,14 @@ struct Args {
#[arg(long)]
task: Option<Task>,

/// Timestamps mode, this is not fully implemented yet.
#[arg(long)]
/// Timestamps mode.
#[arg(long, default_value_t = true)]
timestamps: bool,

/// Maximum initial timestamp index to consider.
#[arg(long)]
max_initial_timestamp_index: Option<u32>,

/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,
Expand Down Expand Up @@ -590,6 +763,7 @@ fn main() -> Result<()> {
language_token,
args.task,
args.timestamps,
args.max_initial_timestamp_index,
args.verbose,
)?;
dc.run(&mel)?;
Expand Down
Loading