forked from JuliaGPU/Metal.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrandom.jl
69 lines (59 loc) · 3.13 KB
/
random.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
using Random
using ..MPS: MPSVector, _mpsmat_rand!, MPSMatrixRandomUniformDistributionDescriptor,
MPSMatrixRandomNormalDistributionDescriptor
# The GPUArrays-provided default RNG for `MtlArray`; used as the general-purpose fallback.
function gpuarrays_rng()
    return GPUArrays.default_rng(MtlArray)
end

# The default RNG of the `MPS` submodule; used where the MPS-backed path is allowed.
function mpsrand_rng()
    return MPS.default_rng()
end
# GPUArrays in-place: fill any `MtlArray` using the GPUArrays RNG. Arrays matching
# `MPS.UniformArray` / `MPS.NormalArray` have separate methods below that may pick the
# MPS RNG instead.
function Random.rand!(A::MtlArray)
    rng = gpuarrays_rng()
    return Random.rand!(rng, A)
end
function Random.randn!(A::MtlArray)
    rng = gpuarrays_rng()
    return Random.randn!(rng, A)
end
# Decide whether the MPS RNG may fill `A`. Two conditions must hold: the start offset
# (`A.offset` scaled by `sizeof(T)`) and the total byte size `sizeof(A)` must both be
# multiples of 4. The in-place MPS methods below use this to choose between the MPS and
# GPUArrays RNGs.
# NOTE(review): scaling by `sizeof(T)` implies `A.offset` counts elements, not bytes —
# confirm against the `MtlArray` definition.
@inline function can_use_mpsrandom(A::MtlArray{T}) where {T}
    offset_bytes = A.offset * sizeof(T)
    return iszero(offset_bytes % 4) && iszero(sizeof(A) % 4)
end
# Use MPS random functionality where possible: when `can_use_mpsrandom(A)` holds the
# array is filled by the MPS RNG, otherwise by the GPUArrays RNG.
function Random.rand!(A::MPS.UniformArray)
    if can_use_mpsrandom(A)
        return Random.rand!(mpsrand_rng(), A)
    else
        return Random.rand!(gpuarrays_rng(), A)
    end
end
function Random.randn!(A::MPS.NormalArray)
    if can_use_mpsrandom(A)
        return Random.randn!(mpsrand_rng(), A)
    else
        return Random.randn!(gpuarrays_rng(), A)
    end
end
# GPUArrays out-of-place, shape given as a `Dims` tuple.
#
# For element types matched by `MPS.UniformType` / `MPS.NormalType`, the MPS RNG is used
# when the array's total byte size is a multiple of 4; otherwise the GPUArrays RNG is used.
# Only the size is checked here, unlike `can_use_mpsrandom`, which also checks the offset.
# NOTE(review): presumably fine because the array is freshly allocated — confirm.
function rand(T::MPS.UniformType, dims::Dims; storage=DefaultStorageMode)
    nbytes = prod(dims) * sizeof(T)
    rng = iszero(nbytes % 4) ? mpsrand_rng() : gpuarrays_rng()
    out = MtlArray{T,length(dims),storage}(undef, dims...)
    return Random.rand!(rng, out)
end
function randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode)
    nbytes = prod(dims) * sizeof(T)
    rng = iszero(nbytes % 4) ? mpsrand_rng() : gpuarrays_rng()
    out = MtlArray{T,length(dims),storage}(undef, dims...)
    return Random.randn!(rng, out)
end

# Fallback for any other element type: always the GPUArrays RNG.
function rand(T::Type, dims::Dims; storage=DefaultStorageMode)
    rng = gpuarrays_rng()
    out = MtlArray{T,length(dims),storage}(undef, dims...)
    return Random.rand!(rng, out)
end
function randn(T::Type, dims::Dims; storage=DefaultStorageMode)
    rng = gpuarrays_rng()
    out = MtlArray{T,length(dims),storage}(undef, dims...)
    return Random.randn!(rng, out)
end
# support all dimension specifications: the shape is given as separate integer arguments
# `dim1, dims...`, and the RNG choice is the same as for the `Dims` methods. For
# MPS-matched element types, the MPS RNG is used when the total byte size is a multiple
# of 4; otherwise the GPUArrays RNG is used.
function rand(T::MPS.UniformType, dim1::Integer, dims::Integer...;
              storage=DefaultStorageMode)
    nbytes = dim1 * prod(dims) * sizeof(T)
    rng = iszero(nbytes % 4) ? mpsrand_rng() : gpuarrays_rng()
    out = MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)
    return Random.rand!(rng, out)
end
function randn(T::MPS.NormalType, dim1::Integer, dims::Integer...;
               storage=DefaultStorageMode)
    nbytes = dim1 * prod(dims) * sizeof(T)
    rng = iszero(nbytes % 4) ? mpsrand_rng() : gpuarrays_rng()
    out = MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)
    return Random.randn!(rng, out)
end

# Fallback for any other element type: always the GPUArrays RNG.
function rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
    rng = gpuarrays_rng()
    out = MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)
    return Random.rand!(rng, out)
end
function randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
    rng = gpuarrays_rng()
    out = MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)
    return Random.randn!(rng, out)
end
# untyped out-of-place: the element type defaults to Float32, and the array is always
# filled by the MPS RNG. No size check is done here; with `sizeof(Float32) == 4` the byte
# size is always a multiple of 4.
function rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
    rng = mpsrand_rng()
    shape = (dim1, dims...)
    out = MtlArray{Float32,length(shape),storage}(undef, shape...)
    return Random.rand!(rng, out)
end
function randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
    rng = mpsrand_rng()
    shape = (dim1, dims...)
    out = MtlArray{Float32,length(shape),storage}(undef, shape...)
    return Random.randn!(rng, out)
end
# scalars: draw a 4-element array (byte size `4 * sizeof(T)`, always a multiple of 4)
# with `SharedStorage` by default, and return its first element. `T` defaults to Float32.
function rand(T::Type=Float32; storage=SharedStorage)
    samples = rand(T, 4; storage)
    return samples[1]
end
function randn(T::Type=Float32; storage=SharedStorage)
    samples = randn(T, 4; storage)
    return samples[1]
end
# seeding: seed the GPUArrays RNG and then the MPS RNG with the same value. When no seed
# is given, a fresh `UInt64` is drawn with `Base.rand`. Returns whatever seeding the MPS
# RNG returns.
function seed!(seed=Base.rand(UInt64))
    gpu_rng = gpuarrays_rng()
    Random.seed!(gpu_rng, seed)
    mps_rng = mpsrand_rng()
    return Random.seed!(mps_rng, seed)
end