[fal.ai] longlive/DenoiseBlock: KV cache write_start_index == local_end_index — empty tensor slice causes 310+ chunk errors per session #921

@livepeer-tessa

Description

Summary

The longlive pipeline is producing ~310 ERROR-level chunk failures per session due to an invalid KV cache index range where write_start_index == local_end_index, creating an empty slice assignment that PyTorch rejects. The error fires continuously during active streaming, triggered by rapid mode transitions (video↔text).

cc @mjh1 @emranemran

Error Messages

Error in block: (denoise, DenoiseBlock)
Error details: The expanded size of the tensor (0) must match the existing size (3072) at non-singleton dimension 1.  Target sizes: [1, 0, 12, 128].  Tensor sizes: [3072, 12, 128]
scope.server.pipeline_processor - ERROR - [067a55be] Error processing chunk for longlive: The expanded size of the tensor (0) must match the existing size (3072) at non-singleton dimension 1.  Target sizes: [1, 0, 12, 128].  Tensor sizes: [3072, 12, 128]

Stack Trace

File "/app/src/scope/server/pipeline_processor.py", line 475, in process_chunk
    output_dict = self.pipeline(**call_params)
File "/app/src/scope/core/pipelines/longlive/pipeline.py", line 209, in __call__
    return self._generate(**kwargs)
File "/app/src/scope/core/pipelines/longlive/pipeline.py", line 250, in _generate
    _, self.state = self.blocks(self.components, self.state)
File "/app/.venv/lib/python3.12/site-packages/diffusers/modular_pipelines/modular_pipeline.py", line 932, in __call__
    pipeline, state = block(pipeline, state)
File "/app/src/scope/core/pipelines/wan2_1/blocks/denoise.py", line 185, in __call__
    _, denoised_pred = components.generator(...)
File "/app/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py", line 500, in _forward_inference
    result = block(x, **kwargs)
File "/app/src/scope/core/pipelines/wan2_1/vace/models/attention_blocks.py", line 157, in forward
    result = super().forward(...)
File "/app/src/scope/core/pipelines/longlive/modules/causal_model.py", line 508, in forward
    self_attn_result = self.self_attn(...)
File "/app/src/scope/core/pipelines/longlive/modules/causal_model.py", line 328, in forward
    temp_k[:, write_start_index:local_end_index] = roped_key[...]
RuntimeError: The expanded size of the tensor (0) must match the existing size (3072) at non-singleton dimension 1.  Target sizes: [1, 0, 12, 128].  Tensor sizes: [3072, 12, 128]

Root Cause Analysis

File: src/scope/core/pipelines/longlive/modules/causal_model.py, line 328

The issue is in the self-attention KV cache write:

temp_k[:, write_start_index:local_end_index] = roped_key[...]

When write_start_index == local_end_index, the target slice temp_k[:, 0:0] has shape [1, 0, 12, 128], while roped_key still carries the full chunk sequence length of 3072 ([3072, 12, 128]). PyTorch cannot broadcast a non-empty source into an empty target slice, so the assignment fails with the size-mismatch error above.
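The failure is reproducible in isolation with the shapes from the traceback. A minimal sketch (the cache length of 4096 is an illustrative assumption; only the [3072, 12, 128] source shape and the empty slice are taken from the error):

```python
import torch

# KV cache buffer: [batch, cache_len, heads, head_dim] (cache_len assumed)
temp_k = torch.zeros(1, 4096, 12, 128)
# RoPE-rotated keys for the current chunk, as in the traceback
roped_key = torch.zeros(3072, 12, 128)

# Stale indices after a mode transition: start == end
write_start_index = local_end_index = 0

try:
    # Target slice has shape [1, 0, 12, 128] -- nothing can be written into it
    temp_k[:, write_start_index:local_end_index] = roped_key
except RuntimeError as e:
    print(e)  # "The expanded size of the tensor (0) must match the existing size (3072)..."
```

This confirms the error is purely an index-range problem: the same assignment succeeds whenever the slice length matches the chunk length.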

This likely happens after rapid mode transitions (video→text→video) that cause the cache state to be in an inconsistent position — the indices are computed from cache state that wasn't fully reset between transitions.

Frequency (last 12h, 2026-04-12 06:09 – 18:09 UTC)

  • ~310 total occurrences across session 067a55be
  • Fires continuously throughout session at ~1–2 Hz during streaming
  • Time window: 14:16–14:55 UTC
  • App: github_f1lhgmk5v76a0ev1w0u378by-scope-app--prod

Reproduction Context

Pipeline was loaded with VACE enabled (vace_enabled: True) and rapid mode transitions were occurring:

14:16:29 - handle_mode_transition: Mode changed from text to video, resetting cache
14:16:30 - HARD CUT: Executing cache reset
14:16:31 - handle_mode_transition: Mode changed from video to text, resetting cache
14:16:32 - handle_mode_transition: Mode changed from text to video, resetting cache
14:16:34 - handle_mode_transition: Mode changed from video to text, resetting cache
14:16:35 - handle_mode_transition: Mode changed from text to video, resetting cache
14:16:38 ❌ ERROR: DenoiseBlock tensor size 0

Suggested Fix

In causal_model.py around line 328, add a guard:

if write_start_index < local_end_index:
    temp_k[:, write_start_index:local_end_index] = roped_key[...]
    temp_v[:, write_start_index:local_end_index] = roped_value[...]
else:
    logger.warning(f'KV cache write skipped: write_start_index ({write_start_index}) >= local_end_index ({local_end_index})')

Additionally, investigate cache state consistency after mode transitions — index computation should be re-derived from the reset state to avoid stale values.
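One way to make that re-derivation concrete: compute the write range from the cache's current fill level at write time, rather than carrying indices across transitions. This is a hypothetical sketch, not the actual longlive API; KVCacheState, filled, capacity, and write_range are all illustrative names.

```python
from dataclasses import dataclass

@dataclass
class KVCacheState:
    filled: int      # number of cached positions currently valid
    capacity: int    # total cache length

    def reset(self) -> None:
        """Mode transition: drop cached positions so indices start fresh."""
        self.filled = 0

    def write_range(self, chunk_len: int) -> tuple[int, int]:
        """Derive [start, end) for the next KV write from the fill level."""
        start = self.filled
        end = min(start + chunk_len, self.capacity)
        if start >= end:
            # Surface the inconsistent state instead of attempting an
            # empty-slice assignment that PyTorch will reject
            raise ValueError(f"empty KV write range [{start}, {end})")
        self.filled = end
        return start, end

cache = KVCacheState(filled=0, capacity=4096)
print(cache.write_range(3072))  # (0, 3072)
cache.reset()                    # transition: indices re-derived, never stale
print(cache.write_range(3072))  # (0, 3072)
```

Because the range is always computed from the post-reset fill level, a transition cannot leave start == end unless the cache is genuinely full, and that case raises loudly instead of failing deep inside the attention layer.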
