Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions server/src/common/kvflash_pager.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,28 @@ class KvFlashPager {
cfg.sink_chunks, cfg.tail_window_chunks);
return false;
}
if (attn_k.empty() || attn_k.size() != attn_v.size()) return false;
if (attn_k.size() != attn_v.size()) return false;
cfg_ = cfg;
attn_k_ = attn_k;
attn_v_ = attn_v;
n_blocks_ = cfg.pool_tokens / cfg.chunk_tokens;
const ggml_tensor * K0 = attn_k[0];
if ((int)K0->ne[1] < cfg.pool_tokens) return false;
n_head_kv_ = (int)K0->ne[2];

// Per-(tensor, head) contiguous segment of chunk_tokens rows.
k_seg_bytes_ = (size_t)cfg.chunk_tokens * K0->nb[1];
v_seg_bytes_ = (size_t)cfg.chunk_tokens * attn_v[0]->nb[1];
chunk_bytes_ = (k_seg_bytes_ + v_seg_bytes_) * (size_t)n_head_kv_ * attn_k.size();
zero_buf_.assign(std::max(k_seg_bytes_, v_seg_bytes_), 0);
if (!attn_k.empty()) {
const ggml_tensor * K0 = attn_k[0];
if ((int)K0->ne[1] < cfg.pool_tokens) return false;
n_head_kv_ = (int)K0->ne[2];

// Per-(tensor, head) contiguous segment of chunk_tokens rows.
k_seg_bytes_ = (size_t)cfg.chunk_tokens * K0->nb[1];
v_seg_bytes_ = (size_t)cfg.chunk_tokens * attn_v[0]->nb[1];
chunk_bytes_ = (k_seg_bytes_ + v_seg_bytes_) * (size_t)n_head_kv_ * attn_k.size();
zero_buf_.assign(std::max(k_seg_bytes_, v_seg_bytes_), 0);
} else {
n_head_kv_ = 0;
k_seg_bytes_ = 0;
v_seg_bytes_ = 0;
chunk_bytes_ = 0;
zero_buf_.clear();
}

free_blocks_.clear();
for (int b = n_blocks_ - 1; b >= 0; b--) free_blocks_.push_back(b);
Expand Down Expand Up @@ -195,7 +203,7 @@ class KvFlashPager {
bool page_out(int c) {
if (c >= (int)chunks_.size() || chunks_[c].block < 0) return false;
ChunkState & st = chunks_[c];
if (!st.on_host) {
if (has_tensor_storage() && !st.on_host) {
st.host_data.resize(chunk_bytes_);
stats_.host_bytes += (int64_t)chunk_bytes_;
}
Expand Down Expand Up @@ -343,6 +351,7 @@ class KvFlashPager {
// Move one chunk between pool slots and host backing. Segment order is
// fixed (layer-major, K then V, head-minor) so offsets are stable.
void copy_chunk(int c, int block, bool to_host) {
if (!has_tensor_storage()) return;
ChunkState & st = chunks_[c];
uint8_t * p = st.host_data.data();
for (size_t l = 0; l < attn_k_.size(); l++) {
Expand All @@ -360,6 +369,7 @@ class KvFlashPager {
}

void zero_block(int block) {
if (!has_tensor_storage()) return;
for (size_t l = 0; l < attn_k_.size(); l++) {
for (int kv = 0; kv < 2; kv++) {
ggml_tensor * t = kv == 0 ? attn_k_[l] : attn_v_[l];
Expand All @@ -372,6 +382,10 @@ class KvFlashPager {
}
}

// True only when attn_k_ is non-empty and chunk_bytes_ is positive, i.e.
// there is tensor data for the pager to copy or zero. Callers use this as an
// early-out guard before touching tensor memory.
bool has_tensor_storage() const {
    if (attn_k_.empty()) {
        return false;
    }
    return chunk_bytes_ > 0;
}

KvFlashConfig cfg_;
std::vector<ggml_tensor *> attn_k_, attn_v_;
std::vector<ChunkState> chunks_;
Expand Down
23 changes: 20 additions & 3 deletions server/src/gemma4/gemma4_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,8 @@ void gemma4_layer_step_graph_free(Gemma4LayerStepGraph & sg) {
sg.token_ids = nullptr;
sg.attn_mask_full = nullptr;
sg.attn_mask_swa = nullptr;
sg.kv_idx_full = nullptr;
sg.kv_idx_swa = nullptr;
}

void gemma4_layer_step_graph_destroy(Gemma4LayerStepGraph & sg) {
Expand All @@ -481,9 +483,11 @@ bool build_gemma4_layer_step(
ggml_tensor * act_out,
int chunk_start,
int n_tokens,
int kv_start) {
int kv_start,
const KvFlashPager * kvflash) {
gemma4_layer_step_graph_free(sg);
if (layer_idx < 0 || layer_idx >= w.n_layer) return false;
if (kvflash && cache.fa_window > 0) return false;

ggml_init_params ip{};
ip.mem_size = ggml_tensor_overhead() * 16384 + ggml_graph_overhead() + 16 * 1024 * 1024;
Expand All @@ -508,8 +512,15 @@ bool build_gemma4_layer_step(
sg.token_ids = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, n_tokens);
ggml_set_input(sg.token_ids);

int full_cap = cache.max_ctx;
for (int il = 0; il < (int)cache.k.size(); ++il) {
if (cache.k[(size_t)il] && !gemma4_is_swa_layer(w, il)) {
full_cap = (int)cache.k[(size_t)il]->ne[1];
break;
}
}
const int kv_len_raw = kv_start + n_tokens;
const int kv_len_padded = (kv_len_raw + 255) & ~255;
const int kv_len_padded = std::min((kv_len_raw + 255) & ~255, full_cap);
sg.attn_mask_full = ggml_new_tensor_4d(
sg.ctx, GGML_TYPE_F32, kv_len_padded, n_tokens, 1, 1);
ggml_set_input(sg.attn_mask_full);
Expand All @@ -524,11 +535,17 @@ bool build_gemma4_layer_step(
ggml_set_input(sg.attn_mask_swa);
ggml_tensor * mask_swa_f16 = ggml_cast(sg.ctx, sg.attn_mask_swa, GGML_TYPE_F16);

sg.kv_idx_full = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, n_tokens);
ggml_set_input(sg.kv_idx_full);
sg.kv_idx_swa = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, n_tokens);
ggml_set_input(sg.kv_idx_swa);

ggml_tensor * pl_input = build_gemma4_per_layer_input(
sg.ctx, w, embed, sg.token_ids, n_tokens, layer_idx);
ggml_tensor * layer_out = build_gemma4_layer(
sg.ctx, sg.gf, w, cache, layer_idx, inp, sg.positions,
mask_full_f16, mask_swa_f16, pl_input, kv_start, n_tokens);
mask_full_f16, mask_swa_f16, pl_input, kv_start, n_tokens,
/*capture_idx=*/-1, sg.kv_idx_full, sg.kv_idx_swa);
if (!layer_out) return false;

ggml_tensor * out_view = ggml_view_2d(
Expand Down
5 changes: 4 additions & 1 deletion server/src/gemma4/gemma4_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ struct Gemma4LayerStepGraph {
ggml_tensor * token_ids = nullptr;
ggml_tensor * attn_mask_full = nullptr;
ggml_tensor * attn_mask_swa = nullptr;
ggml_tensor * kv_idx_full = nullptr;
ggml_tensor * kv_idx_swa = nullptr;
};

void gemma4_layer_step_graph_free(Gemma4LayerStepGraph & sg);
Expand All @@ -296,7 +298,8 @@ bool build_gemma4_layer_step(
ggml_tensor * act_out,
int chunk_start,
int n_tokens,
int kv_start);
int kv_start,
const class KvFlashPager * kvflash = nullptr);

bool compute_gemma4_split_argmax(
ggml_backend_t backend,
Expand Down
Loading
Loading