Fix isblasmatrix for GPUArrays#59

Merged: lkdvos merged 11 commits into main from ksh/gemm on Jun 8, 2026

kshyatt commented May 13, 2026

Member

This will help us use the new support for generic GPUArray strided views in a way that bypasses some really awful ambiguity warnings.

kshyatt requested a review from lkdvos May 13, 2026 10:12

github-actions Bot commented May 13, 2026

Your PR no longer requires formatting changes. Thank you for your contribution!

lkdvos left a comment

Member

This doesn't fully fix the issue I think. That code path shouldn't ever be reached by GPU arrays, since it is guarded by an isblasmatrix call that checks pointer(A) isa Ptr.

I think this really needs a proper rewrite that dispatches to a gemm function, which can then determine the appropriate driver.
Note also that the current fallback uses the Strided machinery to write out the kernel manually, which is actually equivalent to what the generic_matmatmul! function does anyway.

kshyatt commented May 13, 2026
Member Author

pointer(A) isa Ptr is true for ROCArrays, so this code path is definitely reached by GPUArrays: that is why it was erroring.

kshyatt commented May 13, 2026
Member Author

> This doesn't fully fix the issue I think

It seems to have unblocked the v1 AMD stuff on TO 🤷. But if someone wants to make a higher-performance version, go ahead. I think the rocBLAS gemm doesn't work well if the stride in the first dimension isn't 1.

codecov Bot commented May 13, 2026

Codecov Report

❌ Patch coverage is 32.14286% with 19 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ext/StridedAMDGPUExt.jl | 0.00% | 9 Missing ⚠️ |
| ext/StridedcuBLASExt.jl | 0.00% | 9 Missing ⚠️ |
| ext/StridedGPUArraysExt.jl | 83.33% | 1 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| src/linalg.jl | 70.10% <100.00%> (+47.99%) ⬆️ |
| ext/StridedGPUArraysExt.jl | 49.25% <83.33%> (+3.35%) ⬆️ |
| ext/StridedAMDGPUExt.jl | 0.00% <0.00%> (ø) |
| ext/StridedcuBLASExt.jl | 0.00% <0.00%> (ø) |

... and 1 file with indirect coverage changes

lkdvos commented May 13, 2026
Member

Would it then not make more sense to fix the isblasmatrix implementation? That at least treats CUDA and AMD equally then.

kshyatt commented May 13, 2026
Member Author

So I think the CUDA one doesn't ever touch this because the result of pointer(A) there is a CuPtr, NOT a Ptr 🙃. Not sure about Metal?

lkdvos commented May 13, 2026
Member

Yes, exactly. I think my argument is to either:
A. Make AMD also not touch this, so it is treated the same as the other backends.
B. Make every backend dispatch to gemm where possible and otherwise use the Strided implementation, similar to how the current CPU version works.
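
Option B can be illustrated with a small CPU-only sketch. Everything named here (`is_blas_like`, `fallback_mul!`, `mul_sketch!`) is hypothetical and is not Strided.jl's actual code; the eligibility test is an assumed stand-in for isblasmatrix, and the triple loop stands in for the generic strided kernel.

```julia
using LinearAlgebra

# Hypothetical eligibility check (an assumption, not Strided.jl's isblasmatrix):
# BLAS element type, unit stride along the first dimension, and a leading
# dimension that is at least the number of rows.
function is_blas_like(A::StridedMatrix{T}) where {T}
    T <: LinearAlgebra.BlasFloat || return false
    return stride(A, 1) == 1 && stride(A, 2) >= max(1, size(A, 1))
end
is_blas_like(::AbstractMatrix) = false

# Hypothetical fallback: a plain triple loop standing in for the generic kernel.
function fallback_mul!(C, A, B)
    fill!(C, zero(eltype(C)))
    for j in axes(B, 2), k in axes(A, 2), i in axes(A, 1)
        C[i, j] += A[i, k] * B[k, j]
    end
    return C
end

# Use gemm when every operand qualifies, otherwise fall back.
function mul_sketch!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, B::AbstractMatrix{T}) where {T}
    if is_blas_like(C) && is_blas_like(A) && is_blas_like(B)
        return BLAS.gemm!('N', 'N', one(T), A, B, zero(T), C)
    end
    return fallback_mul!(C, A, B)
end
```

A view such as `view(M, 1:2:7, :)` has a non-unit first stride, so in this sketch it takes the fallback path, while plain matrices and contiguous-column views go to gemm.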

kshyatt changed the title from "Use a pass-through for gemm" to "Fix isblasmatrix for GPUArrays" on Jun 2, 2026

kshyatt commented Jun 2, 2026
Member Author

Finally got back to this, looks like it should be ok now?

kshyatt commented Jun 4, 2026
Member Author

After a Zulip discussion, we decided to try inserting a new blas_mul! function that we can override where appropriate for the GPU arrays, allowing us to pass work to the efficient vendor libraries where we can. I've added extensions for this.
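
As a rough illustration of the overridable-hook idea, here is a sketch using multiple dispatch. The name `blas_mul_sketch!`, its argument list, `ToyDeviceMatrix`, and `VENDOR_CALLS` are all assumptions made for this example; the real blas_mul! signature and the extension code are not shown in this thread.

```julia
using LinearAlgebra

# Hypothetical hook (the real blas_mul! signature is not shown in this thread).
# Default method: CPU strided matrices with BLAS element types go to gemm!.
function blas_mul_sketch!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T},
                          α::T, β::T) where {T<:LinearAlgebra.BlasFloat}
    return BLAS.gemm!('N', 'N', α, A, B, β, C)
end

# Toy stand-in for a device array type; it just wraps host memory.
struct ToyDeviceMatrix{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::ToyDeviceMatrix) = size(A.data)
Base.getindex(A::ToyDeviceMatrix, i::Int, j::Int) = A.data[i, j]

# Counts how often the "vendor" method is chosen, for demonstration only.
const VENDOR_CALLS = Ref(0)

# What a package extension could add: a more specific method for its own
# array type. Here mul! on the wrapped data stands in for a vendor gemm call.
function blas_mul_sketch!(C::ToyDeviceMatrix{T}, A::ToyDeviceMatrix{T}, B::ToyDeviceMatrix{T},
                          α::T, β::T) where {T}
    VENDOR_CALLS[] += 1
    mul!(C.data, A.data, B.data, α, β)
    return C
end
```

The caller only ever invokes the hook; which library runs is decided by the argument types, so each backend can supply its own method without touching the shared code.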

Katharine Hyatt and others added 2 commits June 4, 2026 16:31

lkdvos left a comment

Member

Only have some comments about code coverage, otherwise looks very clean!

Comment thread test/gpu.jl
Comment thread test/gpu.jl

kshyatt commented Jun 5, 2026
Member Author

Added some more tests to trigger the fallback to the GPUArrays kernels

lkdvos commented Jun 6, 2026
Member

Can we also exhaustively test all the mul cases that trigger the BLAS path? I don't think we're exercising the adjoint path right now, for example.

lkdvos commented Jun 7, 2026
Member

I'm very sorry for this inefficient back-and-forth; I'll be back at my desk next week, so things should hopefully go better then. I think the non-BLAS code paths are indeed checked now, but I was mostly talking about exhaustively testing the various BLAS paths with the different combinations of transposed, conjugated, etc., since that seems like the more brittle part (basically replacing the randperm tests with one test per permutation).
I should find some time to do this tomorrow if it's not clear what I mean. Sorry to keep dragging this on.

kshyatt commented Jun 7, 2026
Member Author

It's cool, no worries. I can also test a BLAS one (with all 1 strides) for each f1, f2 combo?

lkdvos commented Jun 7, 2026
Member

Yeah, I think I was mostly worried about the transpose and adjoint versions, where getblasmatrix would generate non-trivial 'T' or 'C' characters, etc.
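
A minimal CPU-side sketch of what exhaustive flag testing could look like, using the standard BLAS operation characters 'N', 'T' and 'C' with `LinearAlgebra.BLAS.gemm!`. The helper names `apply_op` and `check_all_gemm_flags` are hypothetical, and this is not the test code added in the PR, which targets GPU arrays.

```julia
using LinearAlgebra

# Apply a BLAS operation character using ordinary Julia wrappers.
apply_op(A, c::Char) = c == 'N' ? A : c == 'T' ? transpose(A) : adjoint(A)

# Check gemm! against a reference product for every (tA, tB) combination,
# one explicit case per combination instead of a random sample.
function check_all_gemm_flags(::Type{T}; m = 3, k = 4, n = 5) where {T}
    for tA in ('N', 'T', 'C'), tB in ('N', 'T', 'C')
        # Stored shapes depend on the flag so that op(A) is m×k and op(B) is k×n.
        A = tA == 'N' ? rand(T, m, k) : rand(T, k, m)
        B = tB == 'N' ? rand(T, k, n) : rand(T, n, k)
        C = zeros(T, m, n)
        BLAS.gemm!(tA, tB, one(T), A, B, zero(T), C)
        C ≈ apply_op(A, tA) * apply_op(B, tB) || return false
    end
    return true
end
```

Running it with a complex element type matters, because 'T' and 'C' only differ when conjugation changes the values.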

kshyatt commented Jun 7, 2026
Member Author

I'll try to get to it tonight or tomorrow :)

kshyatt commented Jun 8, 2026
Member Author

Good call, as I ended up finding and fixing a bug.

lkdvos left a comment

Member

Thanks for adding the tests!
Overall looks good to me, left a small note about the gemm call for AMD, but otherwise good to go!

Comment thread ext/StridedAMDGPUExt.jl
Comment on lines +12 to +14
A2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(A2), size(A2))
B2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(B2), size(B2))
C2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(C2), size(C2))

Member

I just checked the actual BLAS wrapper implementation. Do you think it would be worth simply calling that directly, so we can actually deal with the cases where the stride is not equal to the size? (https://github.com/JuliaGPU/AMDGPU.jl/blob/f49923a4c13b06325ff32696952b34d5ec73998f/src/blas/wrappers.jl#L547-L566)

Alternatively, we could keep that for MatrixAlgebraKit and future implementations; I'm also happy to get this in as it is.

Member Author

I considered doing that, but then we would have to manage the library handle etc. ourselves, which I'd prefer not to do. I'd say merge for now and we can revisit if needed.

Member

So does this silently fail for matrices that do not have stride(A, 2) == size(A, 1)?

Would it have been an option to do something like

A2a = view(Base.unsafe_wrap(ROCMatrix{T}, pointer(A2), (stride(A2, 2), size(A2, 2))), 1:size(A2, 1), :)
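
The wrapping idea above (reinterpret the memory using the leading dimension as the row count, then view only the rows in use) can be tried on the CPU. This sketch assumes a plain `Array` in place of `ROCMatrix`, relies on Julia's column-major layout where the leading dimension is `stride(A, 2)`, and uses a hypothetical helper name `wrap_with_leading_dim`; it says nothing about how AMDGPU.jl handles the same call.

```julia
# CPU stand-in: a plain Array instead of ROCMatrix (an assumption for illustration).
function wrap_with_leading_dim(A::StridedMatrix{T}) where {T}
    stride(A, 1) == 1 || throw(ArgumentError("expected unit stride along the first dimension"))
    lda = stride(A, 2)  # leading dimension of the underlying memory
    # Reinterpret the memory as an lda×n matrix. This does not copy, and the
    # nominal extent may run past the parent buffer, so only the rows selected
    # by the view below may be accessed. The caller must keep the parent alive.
    padded = unsafe_wrap(Array, pointer(A), (lda, size(A, 2)))
    return view(padded, 1:size(A, 1), :)
end

parent_mat = collect(reshape(1.0:20.0, 5, 4))  # 5×4 parent, leading dimension 5
A2 = view(parent_mat, 1:3, :)                  # 3×4 view with stride(A2, 2) == 5
W = GC.@preserve parent_mat wrap_with_leading_dim(A2)
```

Here `W` addresses the same elements as `A2` but is backed by a dense array whose first dimension equals the leading dimension, which is the shape a gemm wrapper taking dense matrices plus views would need.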

Jutho commented Jun 9, 2026

Member

Oh I see, this is at least checked in a specialized isblasmatrix version.

lkdvos merged commit a3183b2 into main on Jun 8, 2026
9 of 13 checks passed
lkdvos deleted the ksh/gemm branch on June 8, 2026 14:16

3 participants