Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,21 @@ StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
cuBLAS = "182d3088-87b7-4494-8cad-fc6afaa545bc"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"

[extensions]
StridedcuBLASExt = "cuBLAS"
StridedGPUArraysExt = "GPUArrays"
StridedAMDGPUExt = "AMDGPU"

[compat]
AMDGPU = "2"
Aqua = "0.8"
Adapt = "4"
CUDACore = "6"
cuBLAS = "6"
cuRAND = "6"
GPUArrays = "11.4.1"
JLArrays = "0.3.1"
Expand All @@ -35,6 +40,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDACore = "bd0ed864-bdfe-4181-a5ed-ce625a5fdea2"
cuBLAS = "182d3088-87b7-4494-8cad-fc6afaa545bc"
cuRAND = "20fd9a0b-12d5-4c2f-a8af-7c34e9e60431"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand All @@ -43,4 +49,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Random", "Aqua", "AMDGPU", "CUDACore", "cuRAND", "GPUArrays", "JLArrays", "Metal", "Adapt"]
test = ["Test", "Random", "Aqua", "AMDGPU", "CUDACore", "cuBLAS", "cuRAND", "GPUArrays", "JLArrays", "Metal", "Adapt"]
19 changes: 19 additions & 0 deletions ext/StridedAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module StridedAMDGPUExt

using Strided, StridedViews, AMDGPU, AMDGPU.rocBLAS, LinearAlgebra
import Strided: blas_mul!

const ROCStridedView{T, N, A <: ROCArray{T}} = StridedViews.StridedView{T, N, A}

function Strided.blas_mul!(C::ROCStridedView{T, 2}, A::ROCStridedView{T, 2}, B::ROCStridedView{T, 2}, α::Number, β::Number) where {T <: LinearAlgebra.BlasFloat}
A2, CA = Strided.getblasmatrix(A)
B2, CB = Strided.getblasmatrix(B)
C2, CC = Strided.getblasmatrix(C)
A2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(A2), size(A2))
B2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(B2), size(B2))
C2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(C2), size(C2))
Comment on lines +12 to +14

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just checked the actual blas wrapper implementation, do you think it would be worth it to simply call that directly, so we can actually deal with the cases where the stride is not equal to the size? (https://github.com/JuliaGPU/AMDGPU.jl/blob/f49923a4c13b06325ff32696952b34d5ec73998f/src/blas/wrappers.jl#L547-L566)

Alternatively, we could also keep that for MatrixAlgebraKit and future implementations, I'm happy to already get this in as well.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered doing that but then we have to handle the library handle etc ourselves, which I'd prefer not to do. I'd say merge for now and we can revisit if needed.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that this silently fail for matrices that do not have stride(A, 1) = size(A, 1) ?

Would it have been an option to do something like

A2a = view(Base.unsafe_wrap(ROCMatrix{T}, pointer(A2), (stride(A2, 1), size(A2, 2)), 1:size(A2, 1), :)

@Jutho Jutho Jun 9, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see this is at least checked in a specialized isblasmatrix version.

AMDGPU.rocBLAS.gemm!(CA, CB, convert(T, α), A2a, B2a, convert(T, β), C2a)
return C
end

end
16 changes: 0 additions & 16 deletions ext/StridedCUDACoreExt.jl

This file was deleted.

15 changes: 15 additions & 0 deletions ext/StridedGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using GPUArrays: Adapt, KernelAbstractions
using GPUArrays.KernelAbstractions: @kernel, @index
using StridedViews: ParentIndex

import Strided: isblasmatrix

ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}

# StridedView backed by any GPU array type, with element type linked to the parent.
Expand Down Expand Up @@ -129,4 +131,17 @@ function Strided._mapreduce_block!(
return nothing
end

function Strided.isblasmatrix(A::GPUStridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat}
if A.op == identity
# unsafe wrap approach doesn't work if second condition not met
return stride(A, 1) == 1 && size(A, 1) == size(parent(A), 1)
elseif A.op == conj
# this is converted to adjoint
# unsafe wrap approach doesn't work if second condition not met
return stride(A, 2) == 1 && size(A, 2) == size(parent(A), 2)
else # should never happen
return false
end
end

end
19 changes: 19 additions & 0 deletions ext/StridedcuBLASExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module StridedcuBLASExt

using Strided, StridedViews, cuBLAS, cuBLAS.CUDACore, LinearAlgebra
import Strided: blas_mul!

const CuStridedView{T, N, A <: CuArray{T}} = StridedViews.StridedView{T, N, A}

function Strided.blas_mul!(C::CuStridedView{T, 2}, A::CuStridedView{T, 2}, B::CuStridedView{T, 2}, α::Number, β::Number) where {T <: LinearAlgebra.BlasFloat}
A2, CA = Strided.getblasmatrix(A)
B2, CB = Strided.getblasmatrix(B)
C2, CC = Strided.getblasmatrix(C)
A2a = Base.unsafe_wrap(CuMatrix{T}, pointer(A2), size(A2))
B2a = Base.unsafe_wrap(CuMatrix{T}, pointer(B2), size(B2))
C2a = Base.unsafe_wrap(CuMatrix{T}, pointer(C2), size(C2))
cuBLAS.gemm!(CA, CB, convert(T, α), A2a, B2a, convert(T, β), C2a)
return C
end

end
12 changes: 10 additions & 2 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,21 @@ function _mul!(
α::Number, β::Number
) where {T <: LinearAlgebra.BlasFloat}
if stride(C, 1) == 1 && isblasmatrix(A) && isblasmatrix(B)
nthreads = use_threaded_mul() ? get_num_threads() : 1
_threaded_blas_mul!(C, A, B, α, β, nthreads)
return blas_mul!(C, A, B, α, β)
else
return __mul!(C, A, B, α, β)
end
end

# for CPU based arrays, this is valid
function blas_mul!(
C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2},
α::Number, β::Number
) where {T <: LinearAlgebra.BlasFloat}
nthreads = use_threaded_mul() ? get_num_threads() : 1
return _threaded_blas_mul!(C, A, B, α, β, nthreads)
end

function _threaded_blas_mul!(
C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2},
α::Number, β::Number,
Expand Down
30 changes: 0 additions & 30 deletions test/cuda.jl

This file was deleted.

61 changes: 60 additions & 1 deletion test/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,21 @@ end
# types to test for
ATs = []
!is_buildkite && push!(ATs, JLArray)
CUDACore.functional() && push!(ATs, CuArray)
CUDACore.functional() && cuBLAS.functional() && push!(ATs, CuArray)
AMDGPU.functional() && push!(ATs, ROCArray)
Metal.functional() && push!(ATs, MtlArray)

@testset "isblasmatrix ($AT)" for AT in ATs
for T in (Float32, ComplexF32)
A1 = StridedView(AT(randn(T, 20, 20)))
@test Strided.isblasmatrix(A1)
Comment thread
kshyatt marked this conversation as resolved.
A2 = view(A1, 1:4:20, 1:5:20)
@test !Strided.isblasmatrix(A2)
A3 = view(conj!(A1), 1:4:20, 1:20) # stride(A3, 2) is not 1
@test !Strided.isblasmatrix(A3)
end
end

@testset "in-place matrix operations ($AT)" for AT in ATs
for T in (Float32, ComplexF32)
A1 = StridedView(randn(T, 20, 20))
Expand All @@ -38,6 +49,38 @@ Metal.functional() && push!(ATs, MtlArray)
end
end

@testset "mul! ($AT{$T})" for AT in ATs, T in (Float32, ComplexF32)
Comment thread
kshyatt marked this conversation as resolved.
N = 2
α = rand(T)
β = rand(T)
dims = ntuple(Returns(div(64, N)), N)
A1 = permutedims(StridedView(rand(T, dims)), randperm(N))
A2 = permutedims(StridedView(rand(T, dims)), randperm(N))
A3 = permutedims(StridedView(rand(T, dims)), randperm(N))
@test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3)
# test BLAS for all op combinations
@testset for sz in ((32, 64), (64, 64), (64, 32))
vA1 = view(StridedView(rand(T, sz)), 1:32, 1:32)
vA2 = view(StridedView(rand(T, sz)), 1:32, 1:32)
vA3 = view(StridedView(rand(T, sz)), 1:32, 1:32)
@testset for f1 in (identity, conj, adjoint, transpose), f2 in (identity, conj, adjoint, transpose)
@test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, f1(vA2), f2(vA3))
end
end
# non-BLAS fallback path
vA1 = view(StridedView(rand(T, (32, 32))), 1:32, 1:32)
vA2 = view(StridedView(rand(T, (32, 64))), 1:32, 1:2:64)
vA3 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32)
@testset for f1 in (identity, conj, adjoint, transpose), f2 in (identity, conj, adjoint, transpose)
@test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, f1(vA2), f2(vA3))
end
# non-BLAS fallback path
vA1 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32)
vA2 = view(StridedView(rand(T, (32, 64))), 1:32, 1:2:64)
vA3 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32)
@test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, vA2, vA3)
end

@testset "map, scale!, axpy!, axpby! ($AT)" for AT in ATs
for T in (Float32, ComplexF32)
for N in 2:6
Expand Down Expand Up @@ -69,6 +112,22 @@ end
end
end

@testset "copy ($AT)" for AT in ATs
N = 2
for m1 in (0, 16, 32), m2 in (0, 16, 32), T in (Float32, ComplexF32)
dims = (m1, m2)
A1 = StridedView(rand(T, dims))
A2 = StridedView(rand(T, dims))
A3 = StridedView(rand(T, dims))
for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint)
axes(f1(A1)) == axes(f2(A2)) || continue
B1 = f1(copy(A1))
B2 = f2(copy(A2))
@test compare((x, y) -> copy!(y, x), AT, B1, B2)
end
end
end

@testset "broadcasting ($AT)" for AT in ATs
for T in (Float32, ComplexF32)
A0 = StridedView(rand(T, ()))
Expand Down
25 changes: 0 additions & 25 deletions test/jlarrays.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Aqua
using Adapt, GPUArrays
using JLArrays
using AMDGPU
using CUDACore, cuRAND
using CUDACore, cuRAND, cuBLAS
using Metal

Random.seed!(1234)
Expand Down
Loading