Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 184 additions & 16 deletions candle-examples/examples/whisper/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ struct Decoder {
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
max_initial_timestamp_index: Option<u32>,
verbose: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
Expand All @@ -110,6 +111,7 @@ impl Decoder {
language_token: Option<u32>,
task: Option<Task>,
timestamps: bool,
max_initial_timestamp_index: Option<u32>,
verbose: bool,
) -> Result<Self> {
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
Expand Down Expand Up @@ -144,6 +146,7 @@ impl Decoder {
tokenizer,
task,
timestamps,
max_initial_timestamp_index,
verbose,
suppress_tokens,
sot_token,
Expand All @@ -157,12 +160,11 @@ impl Decoder {
}

fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder_forward(mel, true)?;
let audio_features = self.model.encoder_forward(mel, true)?;
if self.verbose {
println!("audio features: {:?}", audio_features.dims());
}
let sample_len = model.config().max_target_positions / 2;
let sample_len = self.model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
Expand All @@ -182,29 +184,33 @@ impl Decoder {
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
let ys = self
.model
.decoder_forward(&tokens_t, &audio_features, i == 0)?;

// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = self.model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}

let (_, seq_len, _) = ys.dims3()?;
let logits = model
let logits = self
.model
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
// ApplyTimestampRules, i.e.:
// - Timestamps come in pairs, except before EOT.
// - Timestamps should be non-decreasing.
// - If the sum of the probabilities of timestamps is higher than any other tokens,
// only consider timestamps when sampling.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439

// Apply timestamp rules when timestamps are enabled
let logits = if self.timestamps {
self.apply_timestamp_rules(&logits, &tokens)?
} else {
logits
};

let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
Expand All @@ -224,7 +230,9 @@ impl Decoder {
let prob = softmax(&logits, candle::D::Minus1)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
if next_token == self.eot_token
|| tokens.len() > self.model.config().max_target_positions
{
break;
}
sum_logprob += prob.ln();
Expand Down Expand Up @@ -265,6 +273,161 @@ impl Decoder {
unreachable!()
}

/// Constrains the next-token logits so that sampled timestamps follow Whisper's
/// timestamp decoding rules.
///
/// `input_logits` is the rank-1 logits tensor for the next token and `tokens` is the
/// sequence decoded so far, prompt included. Every token id greater than
/// `no_timestamps_token` is treated as a timestamp token. The prompt is assumed to be
/// 3 tokens long when a language token is set and 2 otherwise (verify against the
/// prompt built in `decode`).
///
/// The rules mirror `ApplyTimestampRules` from the reference implementation:
/// 1. Timestamps come in pairs, except directly before EOT.
/// 2. Timestamps never decrease, and a new segment cannot have zero length.
/// 3. The first sampled token must be a timestamp, optionally capped by
///    `max_initial_timestamp_index`.
/// 4. If the total probability mass of timestamps exceeds the probability of every
///    individual non-timestamp token, only timestamps remain eligible.
///
/// Returns a new tensor in which forbidden tokens have a logit of negative infinity.
///
/// # Errors
///
/// Propagates any tensor creation, broadcast, softmax or conversion error.
fn apply_timestamp_rules(&self, input_logits: &Tensor, tokens: &[u32]) -> Result<Tensor> {
    // Marks `mask[start..end]` as forbidden; the range is clamped to the mask length so
    // out-of-vocabulary bounds are silently ignored instead of panicking.
    fn suppress(mask: &mut [f32], start: u32, end: u32) {
        let end = (end as usize).min(mask.len());
        let start = (start as usize).min(end);
        mask[start..end].fill(f32::NEG_INFINITY);
    }

    let device = input_logits.device();
    let timestamp_begin = self.no_timestamps_token + 1;
    let vocab_size = self.model.config().vocab_size as u32;
    let is_timestamp = |token: u32| token >= timestamp_begin;

    // Only the tokens sampled after the prompt take part in the rules.
    let sample_begin = if self.language_token.is_some() { 3 } else { 2 };
    let sampled_tokens = tokens.get(sample_begin..).unwrap_or(&[]);

    // All static constraints are accumulated in one host-side mask and uploaded once.
    let mut mask = vec![0f32; vocab_size as usize];

    // ========== RULE 1: Timestamp pairing constraints ==========
    let last_was_timestamp = matches!(sampled_tokens.last(), Some(&t) if is_timestamp(t));
    // With fewer than two sampled tokens the penultimate token counts as a timestamp,
    // as in the reference implementation: the initial timestamp is then followed by
    // text rather than by a second timestamp.
    let penultimate_was_timestamp = match sampled_tokens.len().checked_sub(2) {
        Some(idx) => is_timestamp(sampled_tokens[idx]),
        None => true,
    };
    if last_was_timestamp {
        if penultimate_was_timestamp {
            // A pair was just closed: the next token has to be a non-timestamp.
            suppress(&mut mask, timestamp_begin, vocab_size);
        } else {
            // A pair is open: normal text is forbidden, only EOT and above remain.
            suppress(&mut mask, 0, self.eot_token);
        }
    }

    // ========== RULE 2: Non-decreasing timestamp constraint ==========
    let last_timestamp = sampled_tokens
        .iter()
        .rev()
        .copied()
        .find(|&t| is_timestamp(t));
    if let Some(last_timestamp) = last_timestamp {
        // When closing a pair the same timestamp may be repeated; otherwise the next
        // timestamp must be strictly greater so that segments have a nonzero length.
        let first_allowed = if last_was_timestamp && !penultimate_was_timestamp {
            last_timestamp
        } else {
            last_timestamp.saturating_add(1)
        };
        suppress(&mut mask, timestamp_begin, first_allowed);
    }

    // ========== RULE 3: Force initial timestamp ==========
    if tokens.len() == sample_begin {
        suppress(&mut mask, 0, timestamp_begin);
        if let Some(max_initial_timestamp_index) = self.max_initial_timestamp_index {
            // Saturating arithmetic: the index is user supplied and may be arbitrarily large.
            let last_allowed = timestamp_begin.saturating_add(max_initial_timestamp_index);
            suppress(&mut mask, last_allowed.saturating_add(1), vocab_size);
        }
    }

    let mut logits = if mask.iter().any(|&m| m != 0.0) {
        input_logits.broadcast_add(&Tensor::new(mask.as_slice(), device)?)?
    } else {
        input_logits.clone()
    };

    // ========== RULE 4: Probability-based timestamp preference ==========
    // These are plain probabilities (softmax), not log-probabilities.
    let probs: Vec<f32> = softmax(&logits, 0)?.to_vec1()?;
    let split = (timestamp_begin as usize).min(probs.len());
    let (text_probs, timestamp_probs) = probs.split_at(split);
    let timestamp_prob_sum: f32 = timestamp_probs.iter().sum();
    let max_text_prob = text_probs.iter().copied().fold(0.0f32, f32::max);

    if timestamp_prob_sum > max_text_prob {
        // Timestamps dominate: only consider timestamp tokens when sampling.
        let mut text_mask = vec![0f32; vocab_size as usize];
        suppress(&mut text_mask, 0, timestamp_begin);
        logits = logits.broadcast_add(&Tensor::new(text_mask.as_slice(), device)?)?;
    }

    Ok(logits)
}

fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
Expand Down Expand Up @@ -465,10 +628,14 @@ struct Args {
#[arg(long)]
task: Option<Task>,

/// Timestamps mode, this is not fully implemented yet.
#[arg(long)]
/// Timestamps mode.
#[arg(long, default_value_t = true)]
timestamps: bool,

/// Maximum initial timestamp index to consider.
#[arg(long)]
max_initial_timestamp_index: Option<u32>,

/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,
Expand Down Expand Up @@ -590,6 +757,7 @@ fn main() -> Result<()> {
language_token,
args.task,
args.timestamps,
args.max_initial_timestamp_index,
args.verbose,
)?;
dc.run(&mel)?;
Expand Down