Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion candle-examples/examples/gemma4/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,12 @@ fn main() -> Result<()> {
}
};
config.use_flash_attn = args.use_flash_attn;
let model = TextModel::new(&config, vb)?;
let text_vb = if vb.contains_tensor("model.language_model.embed_tokens.weight") {
vb.pp("model").pp("language_model")
} else {
vb.pp("model")
};
let model = TextModel::new(&config, text_vb)?;
ModelKind::TextOnly(model)
};

Expand Down
94 changes: 94 additions & 0 deletions candle-transformers/src/models/gemma4/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ fn default_global_head_dim() -> usize {
/// Serde fallback for the `use_flash_attn` field: flash attention stays off
/// unless the config (or the caller) turns it on explicitly.
fn default_use_flash_attn() -> bool {
    false
}
/// Serde fallback for the `vocab_size_per_layer_input` field, used when the
/// key is missing from the deserialized config.
fn default_vocab_size_per_layer_input() -> usize {
    // 262_144 == 2^18.
    262_144
}
/// Serde fallback for the `hidden_size_per_layer_input` field, used when the
/// key is missing from the deserialized config.
///
/// NOTE(review): a width of zero presumably means per-layer inputs are
/// disabled — confirm against the model code that consumes this field.
fn default_hidden_size_per_layer_input() -> usize {
    0
}

// ── Rope parameters ─────────────────────────────────────────────────────────

Expand Down Expand Up @@ -109,6 +115,14 @@ pub struct Gemma4TextConfig {
pub use_bidirectional_attention: Option<String>,
#[serde(default = "default_use_flash_attn")]
pub use_flash_attn: bool,
#[serde(default = "default_vocab_size_per_layer_input")]
pub vocab_size_per_layer_input: usize,
#[serde(default = "default_hidden_size_per_layer_input")]
pub hidden_size_per_layer_input: usize,
#[serde(default)]
pub num_kv_shared_layers: usize,
#[serde(default)]
pub use_double_wide_mlp: bool,
}

impl Gemma4TextConfig {
Expand Down Expand Up @@ -142,6 +156,36 @@ impl Gemma4TextConfig {
.map(|s| s == "sliding_attention")
.unwrap_or(false)
}

/// Reports whether `layer_idx` lies in the trailing run of
/// `num_kv_shared_layers` layers.
///
/// Always `false` when `num_kv_shared_layers` is zero. The start of the
/// shared range is `num_hidden_layers - num_kv_shared_layers`, clamped at
/// zero, so a shared count larger than the layer count marks every index as
/// shared.
pub fn is_kv_shared_layer(&self, layer_idx: usize) -> bool {
    let shared_count = self.num_kv_shared_layers;
    shared_count != 0 && layer_idx >= self.num_hidden_layers.saturating_sub(shared_count)
}

/// Reports whether the layer at `layer_idx` uses the double-wide MLP.
///
/// This is the case only when the `use_double_wide_mlp` flag is set and the
/// layer is one of the KV-shared layers (see `is_kv_shared_layer`).
pub fn uses_double_wide_mlp(&self, layer_idx: usize) -> bool {
    if !self.use_double_wide_mlp {
        return false;
    }
    self.is_kv_shared_layer(layer_idx)
}

/// Reports whether `layer_idx` is the last layer of its attention type
/// before the KV-shared range begins.
///
/// Returns `false` when the layer is itself KV-shared, or when `layer_types`
/// has no entry for `layer_idx`. Otherwise returns `true` exactly when no
/// later layer in `layer_idx + 1 .. first_kv_shared_layer` has the same
/// `layer_types` entry.
///
/// NOTE(review): presumably this marks the layer whose K/V the shared layers
/// reuse — confirm against the attention implementation.
///
/// The scan is bounded with `take`/`skip` rather than slicing
/// `layer_types[..first_kv_shared_layer]`, so a `layer_types` list shorter
/// than the non-shared layer range (e.g. a truncated config) no longer
/// panics; missing entries are simply treated as absent.
pub fn stores_shared_kv(&self, layer_idx: usize) -> bool {
    if self.is_kv_shared_layer(layer_idx) {
        return false;
    }
    let Some(layer_type) = self.layer_types.get(layer_idx) else {
        return false;
    };
    let first_kv_shared_layer = self
        .num_hidden_layers
        .saturating_sub(self.num_kv_shared_layers);
    // `layer_idx + 1` cannot overflow: `get` succeeded, so `layer_idx < len`.
    let same_type_follows = self
        .layer_types
        .iter()
        .take(first_kv_shared_layer)
        .skip(layer_idx + 1)
        .any(|next| next == layer_type);
    !same_type_follows
}
}

// ── Vision config defaults ──────────────────────────────────────────────────
Expand Down Expand Up @@ -397,3 +441,53 @@ pub struct Gemma4Config {
#[serde(default = "default_video_token_id")]
pub video_token_id: usize,
}

// Unit tests for the KV-sharing / per-layer-input helpers on `Gemma4TextConfig`.
#[cfg(test)]
mod tests {
use super::Gemma4TextConfig;

// Deserializes a 35-layer text config with 20 KV-shared layers and checks the
// layer-classification helpers around the boundary at index 15 (35 - 20).
#[test]
fn gemma4_e2b_text_config_enables_ple_and_shared_kv() {
let json = r#"{
"attention_bias": false,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 1536,
"intermediate_size": 6144,
"num_attention_heads": 8,
"num_hidden_layers": 35,
"num_key_value_heads": 4,
"sliding_window": 512,
"layer_types": [
"sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention",
"full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "sliding_attention",
"full_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "sliding_attention", "full_attention",
"sliding_attention", "sliding_attention", "sliding_attention",
"sliding_attention", "full_attention"
],
"num_kv_shared_layers": 20,
"use_double_wide_mlp": true,
"hidden_size_per_layer_input": 256,
"vocab_size_per_layer_input": 262144
}"#;
let cfg: Gemma4TextConfig =
serde_json::from_str(json).expect("valid Gemma 4 E2B text config");

// The explicit JSON value must override the serde default.
assert_eq!(cfg.hidden_size_per_layer_input, 256);
// 35 layers with 20 shared: indices 0..=14 are regular, 15..=34 are shared.
assert!(!cfg.is_kv_shared_layer(14));
assert!(cfg.is_kv_shared_layer(15));
// The double-wide MLP follows the shared range because the flag is set.
assert!(!cfg.uses_double_wide_mlp(14));
assert!(cfg.uses_double_wide_mlp(15));
// Within 0..15, index 13 is the last "sliding_attention" entry and index 14
// the last "full_attention" entry, so both report `true`.
assert!(cfg.stores_shared_kv(13));
assert!(cfg.stores_shared_kv(14));
// Index 12 is followed by another sliding layer (13); index 15 is itself shared.
assert!(!cfg.stores_shared_kv(12));
assert!(!cfg.stores_shared_kv(15));
}
}
Loading