Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions candle-examples/examples/llama/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use std::io::Write;

use candle_transformers::models::llama as model;
Expand Down Expand Up @@ -55,6 +55,8 @@ enum Which {
SmolLM2_135M,
#[value(name = "SmoLM2-135M-Instruct")]
SmolLM2_135MInstruct,
#[value(name = "mn-violet-lotus")]
MNVioletLotus,
}

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -145,7 +147,7 @@ fn main() -> Result<()> {
None => DType::F16,
};
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let api = ApiBuilder::from_env().with_progress(true).build()?;
let model_id = args.model_id.unwrap_or_else(|| {
let str = match args.which {
Which::V1 => "Narsil/amall-7b",
Expand All @@ -166,6 +168,7 @@ fn main() -> Result<()> {
Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B",
Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
Which::MNVioletLotus => "FallenMerick/MN-Violet-Lotus-12B",
};
str.to_string()
});
Expand All @@ -187,7 +190,8 @@ fn main() -> Result<()> {
| Which::V31Instruct
| Which::V32_3b
| Which::V32_3bInstruct
| Which::Solar10_7B => {
| Which::Solar10_7B
| Which::MNVioletLotus => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::SmolLM2_360M
Expand Down
20 changes: 18 additions & 2 deletions candle-examples/examples/quantized/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ enum Which {
Mistral7bInstruct,
#[value(name = "7b-mistral-instruct-v0.2")]
Mistral7bInstructV02,
#[value(name = "12b-mn-violet-lotus")]
MNVioletLotus,
#[value(name = "7b-zephyr-a")]
Zephyr7bAlpha,
#[value(name = "7b-zephyr-b")]
Expand Down Expand Up @@ -104,6 +106,7 @@ impl Which {
| Self::Starling7bAlpha
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::MNVioletLotus
| Self::Mixtral
| Self::MixtralInstruct
| Self::Mistral7b
Expand All @@ -130,6 +133,7 @@ impl Which {
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::MNVioletLotus
| Self::OpenChat35
| Self::Starling7bAlpha
| Self::L8b
Expand Down Expand Up @@ -159,6 +163,7 @@ impl Which {
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::MNVioletLotus
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::L8b
Expand Down Expand Up @@ -188,6 +193,7 @@ impl Which {
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::MNVioletLotus
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::L8b
Expand Down Expand Up @@ -219,6 +225,7 @@ impl Which {
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Self::MNVioletLotus => "FallenMerick/MN-Violet-Lotus-12B",
Self::OpenChat35 => "openchat/openchat_3.5",
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L8b => "meta-llama/Meta-Llama-3-8B",
Expand Down Expand Up @@ -309,7 +316,9 @@ impl Args {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = hf_hub::api::sync::ApiBuilder::from_env()
.with_progress(true)
.build()?;
let repo = self.which.tokenizer_repo();
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
Expand Down Expand Up @@ -369,6 +378,10 @@ impl Args {
"TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
"mistral-7b-instruct-v0.2.Q4_K_S.gguf",
),
Which::MNVioletLotus => (
"backyardai/MN-Violet-Lotus-12B-GGUF",
"MN-Violet-Lotus-12B.Q4_K_M.gguf",
),
Which::Zephyr7bAlpha => (
"TheBloke/zephyr-7B-alpha-GGUF",
"zephyr-7b-alpha.Q4_K_M.gguf",
Expand Down Expand Up @@ -408,7 +421,9 @@ impl Args {
} else {
"main"
};
let api = hf_hub::api::sync::Api::new()?;
let api = hf_hub::api::sync::ApiBuilder::from_env()
.with_progress(true)
.build()?;
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
Expand Down Expand Up @@ -523,6 +538,7 @@ fn main() -> anyhow::Result<()> {
| Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Mistral7bInstructV02
| Which::MNVioletLotus
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::L70b
Expand Down
29 changes: 21 additions & 8 deletions candle-transformers/src/models/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub struct LlamaConfig {
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub head_dim: Option<usize>,
pub num_key_value_heads: Option<usize>,
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
Expand All @@ -54,6 +55,11 @@ pub struct LlamaConfig {
}

impl LlamaConfig {
/// Returns the per-head dimension.
///
/// Uses the explicit `head_dim` value when it is set; otherwise falls back to
/// `hidden_size / num_attention_heads`.
///
/// # Panics
///
/// Panics on the fallback path if `num_attention_heads` is zero.
pub fn head_dim(&self) -> usize {
    // The fallback is computed lazily so that the integer division (and its
    // divide-by-zero panic) only happens when `head_dim` is actually missing.
    self.head_dim
        .unwrap_or_else(|| self.hidden_size / self.num_attention_heads)
}

/// Returns the number of key/value heads.
///
/// When `num_key_value_heads` is not set, the number of attention heads is
/// used instead.
pub fn num_key_value_heads(&self) -> usize {
    let fallback = self.num_attention_heads;
    self.num_key_value_heads.unwrap_or(fallback)
}
Expand All @@ -71,6 +77,7 @@ impl LlamaConfig {
vocab_size: self.vocab_size,
num_hidden_layers: self.num_hidden_layers,
num_attention_heads: self.num_attention_heads,
head_dim: self.head_dim(),
num_key_value_heads: self.num_key_value_heads(),
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
Expand All @@ -91,6 +98,7 @@ pub struct Config {
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub head_dim: usize,
pub num_key_value_heads: usize,
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
Expand All @@ -110,6 +118,7 @@ impl Config {
vocab_size: 32000,
num_hidden_layers: 32,
num_attention_heads: 32,
head_dim: 128,
num_key_value_heads: 32,
use_flash_attn,
rms_norm_eps: 1e-6,
Expand All @@ -129,6 +138,7 @@ impl Config {
vocab_size: 32000,
num_hidden_layers: 32,
num_attention_heads: 32,
head_dim: 128,
num_key_value_heads: 32,
use_flash_attn,
rms_norm_eps: 1e-5,
Expand All @@ -153,10 +163,9 @@ pub struct Cache {
}

/// Builds the default inverse-frequency table derived from `cfg.rope_theta`.
///
/// For every even index `i` in `0..head_dim` the entry is
/// `1 / rope_theta^(i / head_dim)`, giving `ceil(head_dim / 2)` values.
// NOTE(review): this span comes from a rendered diff, so removed and added
// lines appear back to back without +/- markers. Presumably the lines using
// the local `head_dim` are the old version and the lines reading
// `cfg.head_dim` are their replacements — confirm against the actual commit.
fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
(0..head_dim)
(0..cfg.head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.head_dim as f32))
.collect()
}

Expand Down Expand Up @@ -275,7 +284,7 @@ impl CausalSelfAttention {
cache: &mut Cache,
) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let (b_sz, seq_len, _) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
Expand Down Expand Up @@ -350,7 +359,11 @@ impl CausalSelfAttention {
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
let y = y.transpose(1, 2)?.reshape(&[
b_sz,
seq_len,
self.num_attention_heads * self.head_dim,
])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
Expand All @@ -363,8 +376,8 @@ impl CausalSelfAttention {
let span = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
let size_q = cfg.head_dim * cfg.num_attention_heads;
let size_kv = cfg.head_dim * cfg.num_key_value_heads;
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
Expand All @@ -376,7 +389,7 @@ impl CausalSelfAttention {
o_proj,
num_attention_heads: cfg.num_attention_heads,
num_key_value_heads: cfg.num_key_value_heads,
head_dim: cfg.hidden_size / cfg.num_attention_heads,
head_dim: cfg.head_dim,
use_flash_attn: cfg.use_flash_attn,
span,
span_rot,
Expand Down
1 change: 1 addition & 0 deletions candle-transformers/src/models/llava/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ impl LLaVAConfig {
vocab_size: self.vocab_size,
num_hidden_layers: self.num_hidden_layers,
num_attention_heads: self.num_attention_heads,
head_dim: self.hidden_size / self.num_attention_heads,
num_key_value_heads: self.num_key_value_heads,
rms_norm_eps: self.rms_norm_eps as f64,
rope_theta: self.rope_theta,
Expand Down
45 changes: 25 additions & 20 deletions candle-transformers/src/models/quantized_llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ struct LayerWeights {
ffn_norm: RmsNorm,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
k_dim: usize,
v_dim: usize,
/// RoPE convention: true = NEOX (non-interleaved, pairs i with i+d/2),
/// false = NORM (interleaved, pairs 2i with 2i+1).
/// Must match the model architecture — using the wrong convention corrupts
Expand Down Expand Up @@ -195,19 +196,19 @@ impl LayerWeights {
index_pos: usize,
) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let (b_sz, seq_len, _) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?;

let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.reshape((b_sz, seq_len, self.n_head, self.k_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.reshape((b_sz, seq_len, self.n_kv_head, self.k_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.reshape((b_sz, seq_len, self.n_kv_head, self.v_dim))?
.transpose(1, 2)?
// This call to contiguous ensures that the fast kernel can be called below. It's
// actually a no-op except when processing the initial prompt so has no significant
Expand All @@ -233,21 +234,13 @@ impl LayerWeights {

let y = if q.device().is_metal() && seq_len == 1 {
// SDPA will do MQA for us
candle_nn::ops::sdpa(
&q,
&k,
&v,
None,
false,
1. / (self.head_dim as f32).sqrt(),
1.,
)?
candle_nn::ops::sdpa(&q, &k, &v, None, false, 1. / (self.v_dim as f32).sqrt(), 1.)?
} else {
// Support for MQA, useful for 70B models and mistral.
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;

let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = (q.matmul(&k.t()?)? / (self.k_dim as f64).sqrt())?;
let att = match mask {
None => att,
Some(mask) => {
Expand All @@ -260,7 +253,9 @@ impl LayerWeights {
att.matmul(&v.contiguous()?)?
};

let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = y
.transpose(1, 2)?
.reshape(&[b_sz, seq_len, self.n_head * self.k_dim])?;
let y = self.attention_wo.forward(&y)?;
Ok(y)
}
Expand Down Expand Up @@ -340,7 +335,8 @@ impl ModelWeights {
ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
k_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
v_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
rope_is_neox: false, // GGML format = standard Llama = interleaved
cos: cos.clone(),
sin: sin.clone(),
Expand Down Expand Up @@ -381,10 +377,18 @@ impl ModelWeights {
let n_expert_used = md_get("llama.expert_used_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let key_length = md_get("llama.attention.key_length")
.and_then(|m| m.to_u32())
.and_then(|m| Ok(m as usize))
.unwrap_or(embedding_length / head_count);
let value_length = md_get("llama.attention.value_length")
.and_then(|m| m.to_u32())
.and_then(|m| Ok(m as usize))
.unwrap_or(embedding_length / head_count);
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
Expand Down Expand Up @@ -499,7 +503,8 @@ impl ModelWeights {
ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
k_dim: key_length,
v_dim: value_length,
rope_is_neox,
cos: cos.clone(),
sin: sin.clone(),
Expand Down