Skip to content
Closed
2 changes: 2 additions & 0 deletions dflash/src/common/step_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct StepGraph {
ggml_tensor * target_hidden_cat = nullptr; // draft only
ggml_tensor * positions_k = nullptr; // draft only
ggml_tensor * hidden_input = nullptr; // lm-head projection only
ggml_tensor * sfi_gather_idx = nullptr; // SFI sparse indices [sfi_budget, n_head_kv] i32; null when disabled

// Output
ggml_tensor * logits = nullptr;
Expand All @@ -48,6 +49,7 @@ inline void step_graph_free(StepGraph & sg) {
sg.inp_embed = sg.positions = sg.attn_mask = nullptr;
sg.target_hidden_cat = sg.positions_k = nullptr;
sg.hidden_input = nullptr;
sg.sfi_gather_idx = nullptr;
sg.parent_ids = nullptr;
sg.logits = nullptr;
sg.hidden_states = nullptr;
Expand Down
14 changes: 13 additions & 1 deletion dflash/src/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,14 @@ struct TargetCache {
// cast (ggml_get_to_fp32_cuda).
ggml_tensor * target_feat = nullptr;
int target_feat_cap = 0;

// ── SFI (Slow-Fast Inference) selector state ─────────────────────
// Per full-attn layer importance scores, updated on slow-refresh steps.
// Shape: [max_ctx] f32 per layer. Used to select Top-K indices for
// sparse fast-step attention.
std::vector<std::vector<float>> sfi_selector; // size = n_full_attn (16)
std::vector<std::vector<int>> sfi_selected; // cached merged indices
int sfi_budget = 0; // 0 = disabled; >0 = sparse token budget
};

// Snapshot the current SSM+conv state into TargetCache::*_snap tensors.
Expand Down Expand Up @@ -490,6 +498,8 @@ struct QwenGraphInputs {
int fa_window = 0; // sliding window for FA layers: 0 = full attention
bool last_token_logits_only = false; // if true, only compute logits for last token (prefill optimization)
ggml_tensor * parent_ids = nullptr; // [n_tokens] i32; tree mode when non-null
ggml_tensor * sfi_gather_idx = nullptr; // [sfi_budget, n_head_kv] i32; repeated sparse indices per KV head
int sfi_gather_len = 0; // number of valid indices in sfi_gather_idx
};

struct QwenGraphOutputs {
Expand Down Expand Up @@ -525,7 +535,9 @@ ggml_tensor * build_qwen35_layer(
bool capture,
int fa_window = 0,
ggml_tensor * q_tail_capture = nullptr,
int q_tail_start = 0);
int q_tail_start = 0,
ggml_tensor * sfi_gather_idx = nullptr,
int sfi_gather_len = 0);

} // namespace dflash27b

Expand Down
9 changes: 9 additions & 0 deletions dflash/src/qwen35/graph_builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ bool build_target_step(
gi.fa_window = fa_window;
gi.last_token_logits_only = last_token_logits_only;

if (cache.sfi_budget > 0 && n_tokens == 1 && !with_mask) {
sg.sfi_gather_idx = ggml_new_tensor_2d(
sg.ctx, GGML_TYPE_I32, cache.sfi_budget, w.n_head_kv);
ggml_set_name(sg.sfi_gather_idx, "sfi_gather_idx");
ggml_set_input(sg.sfi_gather_idx);
gi.sfi_gather_idx = sg.sfi_gather_idx;
gi.sfi_gather_len = cache.sfi_budget;
}

QwenGraphOutputs go = build_qwen35_graph(sg.ctx, sg.gf, w, cache, gi);
if (!go.logits) return false;
sg.logits = go.logits;
Expand Down
72 changes: 67 additions & 5 deletions dflash/src/qwen35/qwen35_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "common/sampler.h"
#include "common/io_utils.h"
#include "qwen3/qwen3_drafter.h"
#include "sfi_decode_utils.h"

#include "ggml-cuda.h"

Expand All @@ -22,6 +23,44 @@ namespace dflash27b {
( ((w).eos_chat_id >= 0 && (tok) == (w).eos_chat_id) \
|| ((w).eos_id >= 0 && (tok) == (w).eos_id ) )

namespace {

void sfi_refresh_heuristic(TargetCache & cache, int kv_len) {
if (cache.sfi_budget <= 0 || kv_len <= 0 || cache.sfi_selector.empty()) return;
for (auto & scores : cache.sfi_selector) {
sfi::refresh_selector_heuristic(scores, kv_len);
}
if (!cache.sfi_selected.empty()) {
cache.sfi_selected[0] = sfi::compute_sfi_indices(
cache.sfi_selector[0], kv_len, cache.sfi_budget);
}
}

bool sfi_fill_indices(StepGraph & sg, TargetCache & cache, int kv_len, int n_head_kv) {
if (!sg.sfi_gather_idx || cache.sfi_budget <= 0 || kv_len <= 0) return false;
if (cache.sfi_selected.empty() || cache.sfi_selected[0].empty()) {
if (cache.sfi_selector.empty() || cache.sfi_selector[0].empty()) return false;
cache.sfi_selected[0] = sfi::compute_sfi_indices(
cache.sfi_selector[0], kv_len, cache.sfi_budget);
}

const auto & idx = cache.sfi_selected[0];
if (idx.empty()) return false;

std::vector<int32_t> idx_buf((size_t)cache.sfi_budget * n_head_kv, (int32_t)idx.back());
const int n = std::min((int)idx.size(), cache.sfi_budget);
for (int h = 0; h < n_head_kv; ++h) {
int32_t * head_idx = idx_buf.data() + (size_t)h * cache.sfi_budget;
for (int i = 0; i < n; ++i) head_idx[i] = idx[i];
}

ggml_backend_tensor_set(sg.sfi_gather_idx, idx_buf.data(), 0,
sizeof(int32_t) * idx_buf.size());
return true;
}

} // namespace

// ── Construction / destruction ──────────────────────────────────────────

Qwen35Backend::Qwen35Backend(const Qwen35Config & cfg) : cfg_(cfg) {}
Expand Down Expand Up @@ -453,6 +492,12 @@ int Qwen35Backend::do_prefill(const std::vector<int32_t> & tokens,
return -1;
}

int32_t last_tok = -1;
const size_t argmax_off =
(start + n_tokens < prompt_len) ? 0 : sizeof(int32_t) * (size_t)(n_tokens - 1);
ggml_backend_tensor_get(sg_.argmax_tokens, &last_tok, argmax_off, sizeof(int32_t));
cache_.last_tok = last_tok;

// Snapshot at boundary if requested
if (snap_pos >= 0 && snap_slot >= 0 &&
start + n_tokens >= snap_pos && start < snap_pos) {
Expand All @@ -473,6 +518,8 @@ int Qwen35Backend::do_prefill(const std::vector<int32_t> & tokens,
}
}

sfi_refresh_heuristic(cache_, committed);

return committed;
}

Expand All @@ -492,10 +539,16 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,

const int hidden = w_.n_embd;
const int vocab = w_.n_vocab;
const int refresh_interval = sfi::parse_env_int("DFLASH27B_FA_REFRESH_INTERVAL", 0);
std::vector<float> logits_buf(vocab);
std::vector<float> embed_buf_vec(hidden);
float * embed_buf = embed_buf_vec.data();

if (cache_.sfi_budget > 0 && committed > 0 &&
(cache_.sfi_selected.empty() || cache_.sfi_selected[0].empty())) {
sfi_refresh_heuristic(cache_, committed);
}

for (int i = 0; i < n_gen; i++) {
if (!build_target_step(sg_, w_, cache_, target_backend_,
/*kv_start=*/committed, /*n_tokens=*/1,
Expand All @@ -507,20 +560,24 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
return false;
}

sfi_fill_indices(sg_, cache_, committed, w_.n_head_kv);

// Get last generated token (or first prompt token for first iter)
int32_t tok = out_tokens.empty()
? 0 // Should not happen — prefill emits at least one logit
? cache_.last_tok
: out_tokens.back();

if (i == 0 && out_tokens.empty()) {
// First decode: read argmax from prefill's last logits
int32_t argmax = 0;
ggml_backend_tensor_get(sg_.argmax_tokens, &argmax, 0, sizeof(int32_t));
tok = argmax;
if (tok < 0) return false;
out_tokens.push_back(tok);
io.emit(tok);
if (IS_EOS_TOK(tok, w_)) { io.emit(-1); return true; }
committed++;
cache_.cur_pos = committed;
cache_.last_tok = tok;
if (refresh_interval > 0 && (committed % refresh_interval) == 0) {
sfi_refresh_heuristic(cache_, committed);
}
continue;
}

Expand Down Expand Up @@ -551,6 +608,11 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
io.emit(next_tok);
committed++;
cache_.cur_pos = committed;
cache_.last_tok = next_tok;

if (refresh_interval > 0 && (committed % refresh_interval) == 0) {
sfi_refresh_heuristic(cache_, committed);
}

if (IS_EOS_TOK(next_tok, w_)) break;
}
Expand Down
78 changes: 58 additions & 20 deletions dflash/src/qwen35/qwen35_target_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "internal.h"
#include "delta_net_chunked.h"
#include "kv_quant.h"
#include "sfi_decode_utils.h"

#include <cmath>
#include <cstdio>
Expand Down Expand Up @@ -117,6 +118,14 @@ bool create_target_cache_partial(const TargetWeights & w,
out.ssm_intermediate.assign(n_delta, nullptr);
out.conv_input_cache.assign(n_delta, nullptr);

// SFI selector state: per-layer importance scores (host-side, zero-initialized).
{
const char * sfi_env = std::getenv("DFLASH27B_SFI_BUDGET");
out.sfi_budget = (sfi_env && *sfi_env) ? std::max(0, std::atoi(sfi_env)) : 0;
}
out.sfi_selector.assign(n_full_attn, std::vector<float>(max_ctx, 0.0f));
out.sfi_selected.assign(n_full_attn, std::vector<int>{});

// KV cache element types (resolved from env; aborts on unsupported pair).
ggml_type kv_k_type = GGML_TYPE_Q8_0;
ggml_type kv_v_type = GGML_TYPE_Q8_0;
Expand Down Expand Up @@ -433,6 +442,13 @@ static ggml_tensor * build_swiglu_ffn(ggml_context * ctx, ggml_tensor * cur,
return apply_scale2(ctx, ggml_mul_mat(ctx, L.w_down, gu), L.w_down_s); // [hidden, n_tokens]
}

static int parse_fa_refresh_interval() {
const char * raw = std::getenv("DFLASH27B_FA_REFRESH_INTERVAL");
if (!raw || !*raw) return 0;
int v = std::atoi(raw);
return v > 0 ? v : 0;
}

// Full-attention block (matches llama.cpp's build_layer_attn for qwen35)
//
// `cache_k` / `cache_v` are the persistent KV buffers for this layer
Expand All @@ -457,7 +473,9 @@ static ggml_tensor * build_full_attn_block(
bool kv_k_rotated = false,
int fa_window = 0,
ggml_tensor * q_tail_capture = nullptr,
int q_tail_start = 0
int q_tail_start = 0,
ggml_tensor * sfi_gather_idx = nullptr, // [n_sparse] i32 indices for sparse gather
int sfi_gather_len = 0 // number of sparse indices
) {
const int head_dim = w.n_embd_head_k;
const int n_head = w.n_head;
Expand Down Expand Up @@ -569,13 +587,12 @@ static ggml_tensor * build_full_attn_block(
// When fa_window > 0 and kv_start >= fa_window, only attend to the last
// fa_window positions. This dramatically reduces FA cost during speculative
// decode verify/replay at long contexts (60K+ kv entries).
const int win_start = (fa_window > 0 && kv_start > fa_window)
? (kv_start - fa_window) : 0;
const int kv_len = kv_start + n_tokens;
const int win_len = kv_len - win_start;

const int fattn_stride = (kv_k_type == GGML_TYPE_TQ3_0 || kv_v_type == GGML_TYPE_TQ3_0) ? 256 : 1;
const int win_len_padded = ((win_len + fattn_stride - 1) / fattn_stride) * fattn_stride;
const sfi::AttnWindowSlice ws = sfi::resolve_attn_window_slice(
kv_start, n_tokens,
/*allow_slow_refresh=*/attn_mask == nullptr,
fa_window,
parse_fa_refresh_interval(),
/*uses_256_stride=*/kv_k_type == GGML_TYPE_TQ3_0 || kv_v_type == GGML_TYPE_TQ3_0);

ggml_tensor * Qfa = ggml_permute(ctx, Q, 0, 2, 1, 3);
// When K is rotated (TQ3_0 or explicit FWHT), Q needs forward rotation too.
Expand All @@ -589,13 +606,22 @@ static ggml_tensor * build_full_attn_block(
Qfa = ggml_cont(ctx, Qfa);
}

// K and V from cache: a windowed view starting at win_start.
ggml_tensor * Kfa = ggml_view_3d(ctx, cache_k,
head_dim, win_len_padded, n_head_kv,
cache_k->nb[1], cache_k->nb[2], cache_k->nb[1] * win_start);
ggml_tensor * Vfa = ggml_view_3d(ctx, cache_v,
head_dim, win_len_padded, n_head_kv,
cache_v->nb[1], cache_v->nb[2], cache_v->nb[1] * win_start);
// K and V from cache: either sparse gather (SFI fast step) or windowed view.
ggml_tensor * Kfa;
ggml_tensor * Vfa;
if (sfi_gather_idx && sfi_gather_len > 0 && !ws.used_slow_refresh) {
// SFI fast step: gather only the selected sparse positions.
Kfa = ggml_get_rows(ctx, cache_k, sfi_gather_idx);
Vfa = ggml_get_rows(ctx, cache_v, sfi_gather_idx);
} else {
// Standard windowed view.
Kfa = ggml_view_3d(ctx, cache_k,
head_dim, ws.win_len_padded, n_head_kv,
cache_k->nb[1], cache_k->nb[2], cache_k->nb[1] * ws.win_start);
Vfa = ggml_view_3d(ctx, cache_v,
head_dim, ws.win_len_padded, n_head_kv,
cache_v->nb[1], cache_v->nb[2], cache_v->nb[1] * ws.win_start);
}

// Causal mask: for n_tokens==1 we don't need one (a single query attending
// to all keys is trivially causal). For n_tokens>1 the caller must provide
Expand Down Expand Up @@ -918,7 +944,9 @@ static ggml_tensor * build_single_layer(
bool capture,
int fa_window = 0,
ggml_tensor * q_tail_capture = nullptr,
int q_tail_start = 0)
int q_tail_start = 0,
ggml_tensor * sfi_gather_idx = nullptr,
int sfi_gather_len = 0)
{
const int hidden = w.n_embd;
const float eps = w.rms_eps;
Expand All @@ -942,7 +970,8 @@ static ggml_tensor * build_single_layer(
cache.kv_k_type, cache.kv_v_type,
cache.kv_k_rotated,
fa_window,
q_tail_capture, q_tail_start);
q_tail_capture, q_tail_start,
sfi_gather_idx, sfi_gather_len);
} else {
int dn_idx = 0;
for (int il = 0; il < layer_idx; il++) {
Expand Down Expand Up @@ -1032,6 +1061,10 @@ QwenGraphOutputs build_qwen35_graph(
const int hidden = w.n_embd;
const float eps = w.rms_eps;

// SFI sparse gather: use caller-provided index tensor for fast decode steps.
ggml_tensor * sfi_idx_tensor = in.sfi_gather_idx;
int sfi_gather_len = in.sfi_gather_len;

for (int il = 0; il < w.n_layer; il++) {
const TargetLayer & L = w.layers[il];
const bool is_attn = (((il + 1) % w.full_attention_interval) == 0);
Expand All @@ -1047,7 +1080,9 @@ QwenGraphOutputs build_qwen35_graph(
in.attn_mask, in.kv_start, n_tokens,
cache.kv_k_type, cache.kv_v_type,
cache.kv_k_rotated,
in.fa_window);
in.fa_window,
nullptr, 0,
sfi_idx_tensor, sfi_gather_len);
fa_idx++;
} else {
DeltaNetCapture * cap_ptr = nullptr;
Expand Down Expand Up @@ -1166,11 +1201,14 @@ ggml_tensor * build_qwen35_layer(
bool capture,
int fa_window,
ggml_tensor * q_tail_capture,
int q_tail_start)
int q_tail_start,
ggml_tensor * sfi_gather_idx,
int sfi_gather_len)
{
return build_single_layer(ctx, gf, w, cache, layer_idx, inp, positions,
attn_mask, kv_start, n_tokens, capture, fa_window,
q_tail_capture, q_tail_start);
q_tail_capture, q_tail_start,
sfi_gather_idx, sfi_gather_len);
}

// ─── Cross-request prefix snapshot (Phase A) ─────────────────────────
Expand Down
Loading
Loading