Fix PagedAttention Scheduler O(N^2) Thrashing #2031

Open

glaziermag wants to merge 2 commits into EricLBuehler:master from glaziermag:fix-issue-2024-pa-allocator

Conversation

@glaziermag
Contributor

@glaziermag glaziermag commented Mar 26, 2026

Fixes #2024.

Context

The PagedAttentionScheduler previously invoked bucket_and_preempt_sequences() during the Completion scheduling phase. Since PagedAttention backends inherently support variable sequence lengths during decoding via block tables, strict length bucketing during the completion phase is unnecessary.

This behavior caused severe $O(N^2)$ memory thrashing: running completions whose generated lengths had drifted apart were sorted into non-matching buckets, preempted back to the Waiting state, and had their KV caches evicted. On the subsequent tick, the engine was forced to redundantly re-prefill their entire query contexts, severely degrading continuous batching performance.
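To make the quadratic cost concrete, here is a toy cost model (not mistral.rs code; `reprefill_cost` is a hypothetical helper): a sequence that is preempted on every tick must re-process its entire, growing context each time, so total prefill work grows quadratically in the number of decode steps, versus linear work for a sequence that is never evicted.

```rust
// Hypothetical cost model illustrating the thrashing, not actual scheduler code.
// A preempted sequence re-prefills its whole context (which grows by one token
// per tick) on every tick, so total work is O(ticks^2).
fn reprefill_cost(context_len: usize, ticks: usize) -> usize {
    (0..ticks).map(|t| context_len + t).sum()
}

fn main() {
    // Thrashing: re-prefill a 100-token context on each of 50 ticks.
    let thrashing = reprefill_cost(100, 50);
    // Healthy continuous batching: prefill once, then one token per decode step.
    let healthy = 100 + 50;
    println!("thrashing: {thrashing} tokens processed, healthy: {healthy}");
}
```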

Affected Workloads

This issue affected nearly 100% of continuous batching workloads that served >1 concurrent request. Because standard requests inevitably diverge in generated token length during decoding, they quickly drift into different length buckets, triggering the preemption loop. Single-batch workloads (batch size 1) were unaffected.

The Fix

This PR modifies scheduler.rs so that bucket_and_preempt_sequences behaves safely during completions. Instead of forcibly calling _preempt() on sequences with mismatched lengths or modalities, unmatched sequences are deferred to the next tick (deferred_running). This leverages the PagedAttention backend's native ability to handle jagged context sizes via block tables, without dropping generation caches or triggering index-select bounds mismatches (which can occur in M-RoPE models like Qwen2-VL when modalities are improperly mixed).
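The shape of the change can be sketched as follows. This is an illustrative model only — `Seq`, `Action`, and `bucket_or_defer` are hypothetical names, not the actual scheduler.rs API: sequences that do not match the chosen bucket are kept alive and deferred rather than preempted, so their KV caches survive to the next tick.

```rust
// Illustrative sketch of defer-instead-of-preempt; hypothetical types.
#[derive(Debug, PartialEq)]
enum Action {
    Schedule, // runs this tick
    Defer,    // moved to `deferred_running`; KV cache is kept
}

struct Seq {
    len: usize,
    has_images: bool,
}

// Pick the bucket from the first running sequence; schedule matches and
// defer (rather than preempt) everything that mismatches in length or modality.
fn bucket_or_defer(seqs: &[Seq]) -> Vec<Action> {
    let (bucket_len, bucket_img) = match seqs.first() {
        Some(s) => (s.len, s.has_images),
        None => return Vec::new(),
    };
    seqs.iter()
        .map(|s| {
            if s.len == bucket_len && s.has_images == bucket_img {
                Action::Schedule
            } else {
                Action::Defer
            }
        })
        .collect()
}

fn main() {
    let seqs = vec![
        Seq { len: 10, has_images: false },
        Seq { len: 10, has_images: false },
        Seq { len: 12, has_images: true }, // drifted length + different modality
    ];
    // The mismatched sequence is deferred, not preempted: no cache eviction.
    assert_eq!(
        bucket_or_defer(&seqs),
        vec![Action::Schedule, Action::Schedule, Action::Defer]
    );
    println!("deferred instead of preempted: ok");
}
```

The key design point is that deferral is lossless: a deferred sequence keeps its block table and KV cache, so the worst case is one tick of added latency instead of a full re-prefill.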

Benchmarks & Stability Proofs

To confirm the fix and ensure no regressions in mixed-modality workloads, tests were run on an L4 GPU (CUDA 12.4.1) using Qwen/Qwen2.5-0.5B-Instruct and Qwen/Qwen2-VL-2B-Instruct with --prefix-cache-n 10 enabled.

Case 1: The O(N^2) Allocator Thrashing (Issue #2024)

Testing with 6 concurrent, heavy text-generation requests of intentionally misaligned lengths:

  • Before: The engine preempted actively decoding sequences, causing infinite prefill loops. Throughput bottlenecked at ~2.90 tokens/sec.
  • After: Preemptions are bypassed for active completions. Throughput recovered to ~97.43 tokens/sec natively on the L4. (~33x Speedup).

Before Log:

Starting 6 concurrent heavy requests...
[req1] Finished in 101.44s, 0.49 tok/s (Tokens: 50)
[req2] Finished in 101.83s, 0.49 tok/s (Tokens: 50)
[req4] Finished in 102.77s, 0.49 tok/s (Tokens: 50)
[req3] Finished in 102.77s, 0.49 tok/s (Tokens: 50)
[req5] Finished in 103.18s, 0.48 tok/s (Tokens: 50)
[req6] Finished in 103.62s, 0.48 tok/s (Tokens: 50)
Total time: 103.62s
Overall throughput: 2.90 tok/s

After Log:

Starting 6 concurrent heavy requests...
[req3] Finished in 3.08s, 16.25 tok/s (Tokens: 50)
[req4] Finished in 3.08s, 16.25 tok/s (Tokens: 50)
[req1] Finished in 3.08s, 16.24 tok/s (Tokens: 50)
[req5] Finished in 3.08s, 16.25 tok/s (Tokens: 50)
[req2] Finished in 3.08s, 16.24 tok/s (Tokens: 50)
[req6] Finished in 3.08s, 16.25 tok/s (Tokens: 50)
Total time: 3.08s
Overall throughput: 97.43 tok/s

Case 2: Multi-Modal Stability

Testing asynchronous VLM requests (has_images=true) alongside concurrent prefix-cached text requests (has_images=false).

Before (Unpatched master) Bounds Panic:
Note: Without the subset deferrals, batching vision payloads alongside text-only requests instantly triggered a CUDA assert panic across the M-RoPE boundary.

mistralrs_core::engine::logger: Throughput (T/s) 11.20, Prefix cache hitrate 0.00%, 2 running, 1 waiting
mistralrs_core::engine::logger: Throughput (T/s) 8.10, Prefix cache hitrate 0.00%, 2 running, 1 waiting
mistralrs_core::engine: completion step - Model failed with error: DriverError(CUDA_ERROR_ASSERT, "device-side assert triggered")

# Python Client Output
Starting concurrent ALL SIDE EFFECTS test...
[VISION_REQ] API ERROR: {'message': 'DriverError(CUDA_ERROR_ASSERT, "device-side assert triggered")', 'partial_response': {'id': '1', 'choices': [{'finish_reason': 'error', 'index': 0, 'message': {'content': 'Red', ...}}]}}
[TEXT_REQ_2] API ERROR: {'message': 'No response received from the model.'}
[TEXT_REQ_1] API ERROR: {'message': 'No response received from the model.'}
Total Time: 2.36s

After (Patched fix-issue-2024-pa-allocator) Native Isolation:
Note: With deferral in place, mismatched modalities are isolated into separate ticks without preempting caches; all concurrent inferences completed successfully.

mistralrs_core::engine::logger: Throughput (T/s) 15.60, Prefix cache hitrate 0.00%, 2 running, 1 waiting
mistralrs_core::engine::logger: Throughput (T/s)  8.60, Prefix cache hitrate 0.00%, 2 running, 1 waiting

# Python Client Output
Starting concurrent ALL SIDE EFFECTS test...
[VISION_REQ] Finished in 1.48s:
Red

[TEXT_REQ_1] Finished in 1.45s:
1.988 x 10^30 kg

[TEXT_REQ_2] Finished in 1.47s:
7.35 x 10^22 kg

Total Time: 1.50s
Case 2 Python Test Script
import asyncio
import aiohttp
import time

b64_image = "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAAX0lEQVR4nO3PQQ0AIBDAMMC/50MEj4ZkVbDtWX87OuBVA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA9oFUoUBf3Xr7AgAAAAASUVORK5CYII="

async def fetch(session, payload, idx, name):
    start = time.time()
    try:
        async with session.post(
            'http://localhost:1234/v1/chat/completions',
            json=payload,
            timeout=600
        ) as response:
            res = await response.json()
            if 'choices' not in res:
                print(f"[{name}] API ERROR: {res}")
                return 0
            text = res['choices'][0]['message']['content'].strip()
            elapsed = time.time() - start
            print(f'[{name}] Finished in {elapsed:.2f}s:\n{text}\n')
            return elapsed
    except Exception as e:
        print(f'[{name}] Exp Error: {e}')
        return 0

async def main():
    async with aiohttp.ClientSession() as session:
        vlm_payload = {
            'model': 'Qwen/Qwen2-VL-2B-Instruct',
            'messages': [{
                'role': 'user', 
                'content': [
                    {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{b64_image}'}},
                    {'type': 'text', 'text': 'Describe the primary color present here strictly in one word.'}
                ]
            }],
            'max_tokens': 20,
            'temperature': 0.0
        }

        text_payload_1 = {
            'model': 'Qwen/Qwen2-VL-2B-Instruct',
            'messages': [{'role': 'user', 'content': 'Provide the exact numerical mass of the sun. Format as a single scientific notation number.'}],
            'max_tokens': 20,
            'temperature': 0.0
        }

        text_payload_2 = {
            'model': 'Qwen/Qwen2-VL-2B-Instruct',
            'messages': [{'role': 'user', 'content': 'Provide the exact numerical mass of the moon. Format as a single scientific notation number.'}],
            'max_tokens': 20,
            'temperature': 0.0
        }

        tasks = [
            fetch(session, vlm_payload, 1, "VISION_REQ"),
            fetch(session, text_payload_1, 2, "TEXT_REQ_1"),
            fetch(session, text_payload_2, 3, "TEXT_REQ_2")
        ]
        
        print('Starting concurrent ALL SIDE EFFECTS test...')
        start_total = time.time()
        await asyncio.gather(*tasks)
        print(f'Total Time: {time.time() - start_total:.2f}s')

if __name__ == '__main__':
    asyncio.run(main())

@glaziermag glaziermag force-pushed the fix-issue-2024-pa-allocator branch from ad04e75 to c3e142c on March 26, 2026 01:35
@glaziermag glaziermag marked this pull request as draft on March 26, 2026 01:36
@glaziermag glaziermag force-pushed the fix-issue-2024-pa-allocator branch 3 times, most recently from 8339f6d to 0b25ba8 on March 26, 2026 02:06
@glaziermag glaziermag force-pushed the fix-issue-2024-pa-allocator branch from 0b25ba8 to 0773c31 on March 26, 2026 02:12
@glaziermag glaziermag marked this pull request as ready for review on March 26, 2026 04:21
@glaziermag glaziermag changed the title from "Fix PagedAttention Scheduler O(N^2) Thrashing (Issue #2024)" to "Fix PagedAttention Scheduler O(N^2) Thrashing" on March 26, 2026
emanuele-divizio-quixant added a commit to quixantplc/mistral.rs that referenced this pull request Mar 30, 2026
emanuele-divizio-quixant added a commit to quixantplc/mistral.rs that referenced this pull request Mar 30, 2026
emanueleDiVizio added a commit to emanueleDiVizio/mistral.rs that referenced this pull request Apr 2, 2026
…duler

Reapply upstream fixes from PRs EricLBuehler#2031/EricLBuehler#2034: fix quadratic scheduling
complexity when sequences are waiting, and add FCFS priority ordering
to prevent starvation.
Successfully merging this pull request may close these issues:

paged_attention: O(N²) memory allocator thrashing and continuous batching failure