Skip to content

Commit dd8bf36

Browse files
committed
Fix Voxtral Metal streaming mask
1 parent 7724fd7 commit dd8bf36

2 files changed

Lines changed: 44 additions & 1 deletion

File tree

examples/models/voxtral_realtime/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,8 @@ def create_causal_mask(
11291129
return torch.where(
11301130
valid,
11311131
torch.zeros(1, dtype=dtype, device=start_pos.device),
1132-
torch.tensor(float("-inf"), dtype=dtype, device=start_pos.device),
1132+
# MPS SDPA can propagate NaNs from -inf additive masks in AOTI.
1133+
torch.tensor(-1e9, dtype=dtype, device=start_pos.device),
11331134
)
11341135

11351136

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
from types import ModuleType
9+
from unittest.mock import patch
10+
11+
import torch
12+
13+
with patch.dict(
14+
"sys.modules",
15+
{"executorch.extension.llm.custom_ops.custom_ops": ModuleType("custom_ops")},
16+
):
17+
from executorch.examples.models.voxtral_realtime.model import StandardRingKVCache
18+
19+
20+
class StandardRingKVCacheTest(unittest.TestCase):
21+
def test_additive_mask_uses_finite_negative_values(self):
22+
cache = StandardRingKVCache(window_size=4, n_heads=1, head_dim=2)
23+
24+
mask = cache.create_causal_mask(
25+
torch.tensor(0), seq_len=1, dtype=torch.bfloat16
26+
)
27+
28+
self.assertEqual(mask.dtype, torch.bfloat16)
29+
self.assertTrue(torch.isfinite(mask).all())
30+
self.assertEqual(mask[0, 0].item(), 0)
31+
self.assertLess(mask[0, 1].float().item(), -1e8)
32+
33+
def test_bool_mask_keeps_bool_dtype(self):
34+
cache = StandardRingKVCache(window_size=4, n_heads=1, head_dim=2)
35+
36+
mask = cache.create_causal_mask(torch.tensor(3), seq_len=2, bool_mask=True)
37+
38+
self.assertEqual(mask.dtype, torch.bool)
39+
40+
41+
if __name__ == "__main__":
42+
unittest.main()

0 commit comments

Comments
 (0)