Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Feb 5, 2024
1 parent a4df672 commit f529d28
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 20 deletions.
26 changes: 8 additions & 18 deletions src/cached.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,45 +24,35 @@ end
Base.parent(A::CachedDiskArray) = A.parent
Base.size(A::CachedDiskArray) = size(parent(A))
# These could be more efficient with memory in some cases, but this is simple
readblock!(A::CachedDiskArray, data, I...) = _readblock_cached(A, data, I...)
readblock!(A::CachedDiskArray, data, I::AbstractVector...) = _readblock_cached(A, data, I...)
readblock!(A::CachedDiskArray, data, I...) = _readblock_cached!(A, data, I...)
readblock!(A::CachedDiskArray, data, I::AbstractVector...) = _readblock_cached!(A, data, I...)
# TODO we need to invalidate caches when we write
# writeblock!(A::CachedDiskArray, data, I...) = writeblock!(parent(A), data, I...)

haschunks(A::CachedDiskArray) = haschunks(parent(A))
eachchunk(A::CachedDiskArray) = eachchunk(parent(A))

function readblock!(A::CachedDiskArray{T,N}, data, I...) where {T,N}
function _readblock_cached!(A::CachedDiskArray{T,N}, data, I...) where {T,N}
chunks = eachchunk(A)
chunk_inds = findchunk.(chunks.chunks, I)

chunk_arrays = map(chunks[chunk_inds...]) do c
if haskey(A.cache, c)
A.cache[c]
else
chunk_data = Array{T,N}(undef, length.(I))
A.cache[c] = readblock!(A, chunk_data, I...)
chunk_data = Array{T,N}(undef, length.(I)...)
A.cache[c] = readblock!(parent(A), chunk_data, I...)
end
end
out = ConcatDiskArray(chunk_arrays)

out_chunks = ConcatDiskArray(chunk_arrays)
out_inds = map(i -> i .- first(i) + 1, I)
out_inds = map(i -> i .- first(i) .+ 1, I)

data .= view(out_chunks, out_inds...)
data .= view(out, out_inds...)

return data
end

function _readblock_cached(A, data, I...)
if haskey(A.cache, I)
data .= A.cache[I]
else
readblock!(parent(A), data, I...)
A.cache[I] = copy(data)
end
return data
end

"""
cache(A::AbstractArray; maxsize=1000)
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -709,8 +709,8 @@ end
end

@testset "Cached arrays" begin
A = (1:3000) * (1:1200)'
ch = ChunkedDiskArray((1:3000) * (1:1200)', (20, 5))
A = (1:300) * (1:1200)'
ch = ChunkedDiskArray((1:3000) * (1:1200)', (128, 128))
ca = DiskArrays.CachedDiskArray(ch; maxsize=5)
# Read the original
@test sum(ca) == sum(ca)
Expand Down

0 comments on commit f529d28

Please sign in to comment.