Skip to content

Commit cd54189

Browse files
committed
deploy: 5d2aa76
1 parent 9a4299c commit cd54189

3 files changed

Lines changed: 42 additions & 18 deletions

File tree

_sources/index.rst.txt

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,26 @@ Define a custom flow — centroid-based block-sparse routing in a dozen lines:
3838

3939
.. code-block:: python
4040
41+
from typing import Dict
42+
import torch
43+
44+
from vortex_torch.flow import vFlow, register
45+
from vortex_torch.indexer import GeMM, Mean, topK
46+
from vortex_torch.cache import Mean as CMean
47+
from vortex_torch.abs import ContextBase
48+
49+
4150
@register("custom_sparse_attention")
4251
class CustomSparseAttention(vFlow):
4352
4453
def __init__(self):
4554
super().__init__()
4655
# Indexer-side ops (run every decode step)
47-
self.gemv = GeMV()
48-
self.output_func = topK()
56+
self.mean = Mean(dim=1) # average over the query heads
57+
self.gemm = GeMM() # GeMM(x, y) = y @ xᵀ
58+
self.output_func = topK() # must end in topK / approxTopK
4959
# Cache-side ops (run once per finished page)
50-
self.reduction = CMean(dim=1)
60+
self.reduction = CMean(dim=1) # one centroid (mean key) per page
5161
5262
def forward_indexer(
5363
self,
@@ -56,9 +66,10 @@ Define a custom flow — centroid-based block-sparse routing in a dozen lines:
5666
cache: Dict[str, torch.Tensor], # viewed as [S, r, c] per create_cache()
5767
ctx: ContextBase,
5868
):
59-
q_mean = self.mean(q, ctx=ctx)
60-
score = self.gemv(q_mean, cache["centroids"], ctx=ctx)
61-
self.output_func(score, o, ctx=ctx) # must end in topK / approxTopK
69+
# No native torch ops here — every tensor flows through vortex ops.
70+
q_mean = self.mean(q, ctx=ctx) # [B, 1, D]
71+
score = self.gemm(q_mean, cache["centroids"], ctx=ctx) # [S, 1, 1]
72+
self.output_func(score, o, ctx=ctx) # selected pages -> o
6273
6374
def forward_cache(
6475
self,
@@ -69,7 +80,7 @@ Define a custom flow — centroid-based block-sparse routing in a dozen lines:
6980
# triggered only when a page is finished
7081
self.reduction(cache["k"], cache["centroids"], loc=loc, ctx=ctx)
7182
72-
def create_cache(self, page_size: int, head_dim: int):
83+
def create_cache(self, block_size: int, head_dim: int):
7384
# "k" and "v" are provided automatically — do not declare them
7485
return {"centroids": (1, head_dim)}
7586
@@ -80,7 +91,8 @@ Then run it through an SGLang engine:
8091
llm = sgl.Engine(
8192
model_path="Qwen/Qwen3-0.6B",
8293
page_size=16,
83-
attention_backend="flashinfer", # SGLang's base backend
94+
attention_backend="flashinfer", # mandatory: SGLang's base backend
95+
disable_overlap_schedule=True, # mandatory for vortex sparsity
8496
enable_vortex_sparsity=True, # otherwise computes full attention
8597
vortex_topk_val=30, # pages kept per request
8698
vortex_block_reserved_bos=1, # always-attended prefix blocks

index.html

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -299,16 +299,26 @@ <h2>Installation<a class="headerlink" href="#installation" title="Link to this h
299299
<section id="quick-example">
300300
<h2>Quick Example<a class="headerlink" href="#quick-example" title="Link to this heading"></a></h2>
301301
<p>Define a custom flow — centroid-based block-sparse routing in a dozen lines:</p>
302-
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@register</span><span class="p">(</span><span class="s2">&quot;custom_sparse_attention&quot;</span><span class="p">)</span>
302+
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Dict</span>
303+
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
304+
305+
<span class="kn">from</span><span class="w"> </span><span class="nn">vortex_torch.flow</span><span class="w"> </span><span class="kn">import</span> <span class="n">vFlow</span><span class="p">,</span> <span class="n">register</span>
306+
<span class="kn">from</span><span class="w"> </span><span class="nn">vortex_torch.indexer</span><span class="w"> </span><span class="kn">import</span> <span class="n">GeMM</span><span class="p">,</span> <span class="n">Mean</span><span class="p">,</span> <span class="n">topK</span>
307+
<span class="kn">from</span><span class="w"> </span><span class="nn">vortex_torch.cache</span><span class="w"> </span><span class="kn">import</span> <span class="n">Mean</span> <span class="k">as</span> <span class="n">CMean</span>
308+
<span class="kn">from</span><span class="w"> </span><span class="nn">vortex_torch.abs</span><span class="w"> </span><span class="kn">import</span> <span class="n">ContextBase</span>
309+
310+
311+
<span class="nd">@register</span><span class="p">(</span><span class="s2">&quot;custom_sparse_attention&quot;</span><span class="p">)</span>
303312
<span class="k">class</span><span class="w"> </span><span class="nc">CustomSparseAttention</span><span class="p">(</span><span class="n">vFlow</span><span class="p">):</span>
304313

305314
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
306315
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
307316
<span class="c1"># Indexer-side ops (run every decode step)</span>
308-
<span class="bp">self</span><span class="o">.</span><span class="n">gemv</span> <span class="o">=</span> <span class="n">GeMV</span><span class="p">()</span>
309-
<span class="bp">self</span><span class="o">.</span><span class="n">output_func</span> <span class="o">=</span> <span class="n">topK</span><span class="p">()</span>
317+
<span class="bp">self</span><span class="o">.</span><span class="n">mean</span> <span class="o">=</span> <span class="n">Mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># average over the query heads</span>
318+
<span class="bp">self</span><span class="o">.</span><span class="n">gemm</span> <span class="o">=</span> <span class="n">GeMM</span><span class="p">()</span> <span class="c1"># GeMM(x, y) = y @ xᵀ</span>
319+
<span class="bp">self</span><span class="o">.</span><span class="n">output_func</span> <span class="o">=</span> <span class="n">topK</span><span class="p">()</span> <span class="c1"># must end in topK / approxTopK</span>
310320
<span class="c1"># Cache-side ops (run once per finished page)</span>
311-
<span class="bp">self</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="n">CMean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
321+
<span class="bp">self</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="n">CMean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># one centroid (mean key) per page</span>
312322

313323
<span class="k">def</span><span class="w"> </span><span class="nf">forward_indexer</span><span class="p">(</span>
314324
<span class="bp">self</span><span class="p">,</span>
@@ -317,9 +327,10 @@ <h2>Quick Example<a class="headerlink" href="#quick-example" title="Link to this
317327
<span class="n">cache</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="c1"># viewed as [S, r, c] per create_cache()</span>
318328
<span class="n">ctx</span><span class="p">:</span> <span class="n">ContextBase</span><span class="p">,</span>
319329
<span class="p">):</span>
320-
<span class="n">q_mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
321-
<span class="n">score</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gemv</span><span class="p">(</span><span class="n">q_mean</span><span class="p">,</span> <span class="n">cache</span><span class="p">[</span><span class="s2">&quot;centroids&quot;</span><span class="p">],</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
322-
<span class="bp">self</span><span class="o">.</span><span class="n">output_func</span><span class="p">(</span><span class="n">score</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span> <span class="c1"># must end in topK / approxTopK</span>
330+
<span class="c1"># No native torch ops here — every tensor flows through vortex ops.</span>
331+
<span class="n">q_mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span> <span class="c1"># [B, 1, D]</span>
332+
<span class="n">score</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gemm</span><span class="p">(</span><span class="n">q_mean</span><span class="p">,</span> <span class="n">cache</span><span class="p">[</span><span class="s2">&quot;centroids&quot;</span><span class="p">],</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span> <span class="c1"># [S, 1, 1]</span>
333+
<span class="bp">self</span><span class="o">.</span><span class="n">output_func</span><span class="p">(</span><span class="n">score</span><span class="p">,</span> <span class="n">o</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span> <span class="c1"># selected pages -&gt; o</span>
323334

324335
<span class="k">def</span><span class="w"> </span><span class="nf">forward_cache</span><span class="p">(</span>
325336
<span class="bp">self</span><span class="p">,</span>
@@ -330,7 +341,7 @@ <h2>Quick Example<a class="headerlink" href="#quick-example" title="Link to this
330341
<span class="c1"># triggered only when a page is finished</span>
331342
<span class="bp">self</span><span class="o">.</span><span class="n">reduction</span><span class="p">(</span><span class="n">cache</span><span class="p">[</span><span class="s2">&quot;k&quot;</span><span class="p">],</span> <span class="n">cache</span><span class="p">[</span><span class="s2">&quot;centroids&quot;</span><span class="p">],</span> <span class="n">loc</span><span class="o">=</span><span class="n">loc</span><span class="p">,</span> <span class="n">ctx</span><span class="o">=</span><span class="n">ctx</span><span class="p">)</span>
332343

333-
<span class="k">def</span><span class="w"> </span><span class="nf">create_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">page_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
344+
<span class="k">def</span><span class="w"> </span><span class="nf">create_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">block_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
334345
<span class="c1"># &quot;k&quot; and &quot;v&quot; are provided automatically — do not declare them</span>
335346
<span class="k">return</span> <span class="p">{</span><span class="s2">&quot;centroids&quot;</span><span class="p">:</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">)}</span>
336347
</pre></div>
@@ -339,7 +350,8 @@ <h2>Quick Example<a class="headerlink" href="#quick-example" title="Link to this
339350
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">llm</span> <span class="o">=</span> <span class="n">sgl</span><span class="o">.</span><span class="n">Engine</span><span class="p">(</span>
340351
<span class="n">model_path</span><span class="o">=</span><span class="s2">&quot;Qwen/Qwen3-0.6B&quot;</span><span class="p">,</span>
341352
<span class="n">page_size</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
342-
<span class="n">attention_backend</span><span class="o">=</span><span class="s2">&quot;flashinfer&quot;</span><span class="p">,</span> <span class="c1"># SGLang&#39;s base backend</span>
353+
<span class="n">attention_backend</span><span class="o">=</span><span class="s2">&quot;flashinfer&quot;</span><span class="p">,</span> <span class="c1"># mandatory: SGLang&#39;s base backend</span>
354+
<span class="n">disable_overlap_schedule</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># mandatory for vortex sparsity</span>
343355
<span class="n">enable_vortex_sparsity</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># otherwise computes full attention</span>
344356
<span class="n">vortex_topk_val</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="c1"># pages kept per request</span>
345357
<span class="n">vortex_block_reserved_bos</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="c1"># always-attended prefix blocks</span>

searchindex.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)