Skip to content

Commit

Permalink
[oneMKL] Fix gesvd! (#485)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Dec 3, 2024
1 parent af10c9f commit 0b3955a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
23 changes: 12 additions & 11 deletions lib/mkl/wrappers_lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,30 +304,31 @@ for (bname, fname, elty, relty) in ((:onemklSgesvd_scratchpad_size, :onemklSgesv
jobvt::Char,
A::oneStridedMatrix{$elty})
m, n = size(A)
k = min(m, n)
lda = max(1, stride(A, 2))

U = if jobu === 'A'
oneMatrix{$elty}(undef, m, m)
elseif jobu == 'S' || jobu === 'O'
oneMatrix{$elty}(undef, m, min(m, n))
elseif jobu === 'N'
oneMatrix{$elty}(undef, 0, 0) # Equivalence of CU_NULL?
elseif jobu === 'S'
oneMatrix{$elty}(undef, m, k)
elseif jobu === 'N' || jobu === 'O'
ZE_NULL
else
error("jobu must be one of 'A', 'S', 'O', or 'N'")
end
ldu = U == oneMatrix{$elty}(undef, 0, 0) ? 1 : max(1, stride(U, 2))
S = oneVector{$relty}(undef, min(m, n))
ldu = U == ZE_NULL ? 1 : max(1, stride(U, 2))
S = oneVector{$relty}(undef, k)

Vt = if jobvt === 'A'
oneMatrix{$elty}(undef, n, n)
elseif jobvt === 'S' || jobvt === 'O'
oneMatrix{$elty}(undef, min(m, n), n)
elseif jobvt === 'N'
oneMatrix{$elty}(undef, 0, 0)
elseif jobvt === 'S'
oneMatrix{$elty}(undef, k, n)
elseif jobvt === 'N' || jobvt === 'O'
ZE_NULL
else
error("jobvt must be one of 'A', 'S', 'O', or 'N'")
end
ldvt = Vt == oneArray{$elty}(undef, 0, 0) ? 1 : max(1, stride(Vt, 2))
ldvt = Vt == ZE_NULL ? 1 : max(1, stride(Vt, 2))

queue = global_queue(context(A), device())
scratchpad_size = $bname(sycl_queue(queue), jobu, jobvt, m, n, lda, ldu, ldvt)
Expand Down
10 changes: 10 additions & 0 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,16 @@ end
d_A = oneMatrix(A)
U, Σ, Vt = oneMKL.gesvd!('A', 'A', d_A)
@test A collect(U[:,1:n] * Diagonal(Σ) * Vt)

for jobu in ('A', 'S', 'N', 'O')
for jobvt in ('A', 'S', 'N', 'O')
(jobu == 'A') && (jobvt == 'A') && continue
(jobu == 'O') && (jobvt == 'O') && continue
d_A = oneMatrix(A)
U2, Σ2, Vt2 = oneMKL.gesvd!(jobu, jobvt, d_A)
@test Σ Σ2
end
end
end

@testset "syevd! -- heevd!" begin
Expand Down

2 comments on commit 0b3955a

@maleadt
Copy link
Member

@maleadt maleadt commented on 0b3955a Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

oneAPI.jl 2.0 is a relatively minor release, which the most important change being behind the scenes: GPUArrays.jl v11 has switched to KernelAbstractions.jl (#475).

There is one breaking change: all intrinsics (get_local_id, get_global_idx, etc) are fully 1-based now. They used to return 1-based values, while taking a 0-based dimension index. With the move to SPIRVIntrinsics.jl (#477) providing these intrinsics, the optional dimension index is also 1-based now.
Note that this is unlikely to, or rather should not affect many users, because in most cases the dimension index can be omitted in order to default to the first dimension.

Features

Bug fixes

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/122582

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.0.0 -m "<description of version>" 0b3955a78b3ba5f2d8982f871753a063f549e783
git push origin v2.0.0

Please sign in to comment.