feat: add varlen attention on cpu #777
Conversation
kozistr
left a comment
Aside from one thing, everything else looks good to me!
```rust
let causal_mask = create_causal_mask_batch(seq_len_q, seq_len_k, num_heads, device)?;
attention_scores = attention_scores.add(&causal_mask)?;
```
I was just wondering: it looks like causal_mask and window_mask below are always fp32, while attention_scores could be fp16. I'm not sure if I'm right, but it might fail due to a type mismatch!
Thanks, you are right. I figured that out too.
In this case it's always fp32, but for the Apple Metal backend it could be fp16, AFAIK.
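For reference, a minimal sketch of the fix, assuming candle's `Tensor::dtype` / `Tensor::to_dtype` API (`create_causal_mask_batch` is the helper from this diff); not necessarily how the PR will end up doing it:

```rust
// Sketch: cast the fp32 mask to the scores' dtype before adding, so fp16
// attention_scores (e.g. on the Metal backend) don't hit a dtype mismatch.
let causal_mask = create_causal_mask_batch(seq_len_q, seq_len_k, num_heads, device)?
    .to_dtype(attention_scores.dtype())?;
attention_scores = attention_scores.add(&causal_mask)?;
```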
I am going to mark this PR as draft. I implemented a pretty fast attention primitive here: huggingface/candle#3250. Once that is merged (which I am eagerly waiting for), we can do a simple copy of the function here (without the tests).
What does this PR do?
This PR brings varlen attention to CPU/Metal. It's not softmax-fused / Flash attention, but at least it's unpadded, so we don't OOM.
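To illustrate the idea (a minimal sketch, not this PR's actual code): unpadded varlen attention over packed inputs, assuming candle's `Tensor` API, a hypothetical `naive_varlen_attention` helper, packed `(total_tokens, num_heads, head_dim)` inputs, and a `cu_seqlens` offsets slice. Each sequence is attended to separately, so no `(batch, max_len, …)` padded tensor is ever materialized:

```rust
use candle_core::{Result, Tensor, D};

/// Hypothetical helper: naive unpadded attention over packed sequences.
/// `q`, `k`, `v`: (total_tokens, num_heads, head_dim); `cu_seqlens`: cumulative
/// sequence boundaries, e.g. [0, 3, 10] for two sequences of lengths 3 and 7.
fn naive_varlen_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    cu_seqlens: &[usize],
    softmax_scale: f64,
) -> Result<Tensor> {
    let mut outputs = Vec::with_capacity(cu_seqlens.len().saturating_sub(1));
    for w in cu_seqlens.windows(2) {
        let (start, len) = (w[0], w[1] - w[0]);
        // Slice one sequence out of the packed tokens: (len, heads, dim),
        // then move heads first for batched matmul: (heads, len, dim).
        let qi = q.narrow(0, start, len)?.transpose(0, 1)?.contiguous()?;
        let ki = k.narrow(0, start, len)?.transpose(0, 1)?.contiguous()?;
        let vi = v.narrow(0, start, len)?.transpose(0, 1)?.contiguous()?;
        // Scores are (heads, len, len): not softmax-fused, but bounded by
        // this sequence's own length instead of a padded max length.
        let scores = (qi.matmul(&ki.transpose(1, 2)?)? * softmax_scale)?;
        let probs = candle_nn::ops::softmax(&scores, D::Minus1)?;
        // (heads, len, dim) -> back to the packed layout (len, heads, dim).
        outputs.push(probs.matmul(&vi)?.transpose(0, 1)?);
    }
    // Concatenate along the token axis: (total_tokens, heads, dim).
    Tensor::cat(&outputs, 0)
}
```

The per-sequence scores tensor is only `len × len`, which is what keeps memory bounded; a softmax-fused kernel like the one in huggingface/candle#3250 would avoid materializing even that.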
Fixes # (issue)
Before submitting
insta snapshots?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.