Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions backends/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::models::{
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
FlashJinaCodeBertModel, FlashMistralModel, FlashModernBertModel, FlashNomicBertModel,
FlashQwen2Model, FlashQwen3Model,
FlashPplx1Model, FlashQwen2Model, FlashQwen3Model,
};

#[derive(Debug, Clone, Copy, PartialEq)]
Expand Down Expand Up @@ -366,9 +366,9 @@ impl CandleBackend {
}
(Config::Pplx1(config), Device::Cpu | Device::Metal(_)) => {
// TODO(alvarobartt): Enable Flash Attention with BF16 once supported on Metal
if dtype != DType::F32 {
if dtype == DType::BF16 {
Err(BackendError::Start(
"Pplx1 is only supported in fp32 precision".to_string(),
"Pplx1 is not supported in bf16 precision".to_string(),
))
} else {
tracing::info!("Starting Pplx1 model on {:?}", device);
Expand Down Expand Up @@ -536,12 +536,13 @@ impl CandleBackend {
#[cfg(feature = "cuda")]
(Config::Pplx1(config), Device::Cuda(_)) => {
// TODO(alvarobartt): Enable Flash Attention with BF16 once supported on CUDA
if dtype != DType::F32 {
Err(BackendError::Start(
"Pplx1 is only supported in fp32 precision".to_string(),
if dtype == DType::F16 && use_flash_attn(&[FlashAttn::V1, FlashAttn::V2]) {
tracing::info!("Starting FlashPplx1 model on {:?}", device);
Ok(Box::new(
FlashPplx1Model::load(vb, &config, model_type).s()?,
))
} else {
tracing::info!("Starting Pplx1 model on {:?}", device);
tracing::info!("Starting Pplx1 model on {:?}", device);
Ok(Box::new(Pplx1Model::load(vb, &config, model_type).s()?))
}
}
Expand Down
57 changes: 57 additions & 0 deletions backends/candle/src/models/flash_pplx1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use crate::models::{Model, Pplx1Config, FlashQwen3Model};
use candle::{Result, Tensor};
use candle_nn::VarBuilder;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

/// Pplx1 embedding model accelerated with Flash Attention.
///
/// Thin wrapper around [`FlashQwen3Model`]: Pplx1 reuses the Qwen3 backbone
/// (with `use_bidirectional_attention=true` coming from the config — see
/// `load`) and applies a tanh + INT8-range scaling step to the pooled
/// embeddings in `forward`.
pub struct FlashPplx1Model {
    // Underlying Flash-Attention Qwen3 backbone that produces the raw
    // (pooled, per-token) embeddings.
    inner: FlashQwen3Model,
}

impl FlashPplx1Model {
    /// Build a Flash-Attention Pplx1 model from the given weights.
    ///
    /// Only `ModelType::Embedding` with [`Pool::Mean`] is supported; a
    /// `classifier` head or any other pooling strategy is rejected with an
    /// error before the backbone is loaded.
    pub fn load(vb: VarBuilder, config: &Pplx1Config, model_type: ModelType) -> Result<Self> {
        // Reject unsupported configurations up front.
        match &model_type {
            ModelType::Classifier => {
                candle::bail!("`classifier` model type is not supported for Pplx1")
            }
            ModelType::Embedding(pool) if pool != &Pool::Mean => {
                candle::bail!("Pplx1 only supports mean pooling, got {:?}", pool)
            }
            ModelType::Embedding(_) => {}
        }

        // NOTE: Qwen3 but the `config` contains `use_bidirectional_attention=true`
        let inner = FlashQwen3Model::load(vb, config, model_type)?;

        Ok(Self { inner })
    }

    /// Run the backbone and apply the Pplx1 quantization to the pooled
    /// embeddings; the raw (per-token) embeddings pass through untouched.
    pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
        let (pooled, raw) = self.inner.forward(batch)?;

        let pooled = match pooled {
            None => None,
            Some(embeddings) => {
                // tanh squashes values into [-1, 1]; the affine scale then
                // maps them into the INT8 range [-127, 127] and `round`
                // snaps them to integer values.
                // NOTE: To benefit from the INT8 quantization / scaling, the
                // `normalize` parameter when generating embeddings should be
                // set to `false`, otherwise the quantization is "lost"
                let quantized = embeddings.tanh()?.affine(127.0, 0.0)?.round()?;
                Some(quantized)
            }
        };

        Ok((pooled, raw))
    }
}

impl Model for Pplx1Model {
fn is_padded(&self) -> bool {
self.inner.is_padded()
}

fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}
}