Re-indexing broken on GB test code #761
Reproducible via
(with the latest Reactant and the upcoming JLL)
ERROR: LoadError: ArgumentError: cannot re-index SubArray with fewer indices than dimensions
This should not occur; please submit a bug report.
Stacktrace:
[1] macro expansion
@ ./subarray.jl:305 [inlined]
[2] reindex(idxs::Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, subidxs::Tuple{})
@ Base ./subarray.jl:305
[3] reindex
@ ./subarray.jl:294 [inlined]
[4] get_ancestor_indices(x::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, false}, indices::CartesianIndex{3})
@ Reactant.TracedUtils ~/reactant/Reactant.jl/src/TracedUtils.jl:115
[5] getindex
@ ~/reactant/Reactant.jl/src/TracedRArray.jl:233 [inlined]
[6] getindex(none::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, false}, none::Tuple{CartesianIndex{3}})
@ Reactant ./<missing>:0
[7] getindex
@ ~/reactant/Reactant.jl/src/TracedRArray.jl:233 [inlined]
[8] call_with_reactant(::Reactant.MustThrowError, ::typeof(getindex), ::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.Slice{…}, Base.Slice{…}, UnitRange{…}}, false}, ::CartesianIndex{3})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[9] iterate
@ ./abstractarray.jl:1209 [inlined]
[10] iterate
@ ./abstractarray.jl:1207 [inlined]
[11] copyto_unaliased!
@ ./abstractarray.jl:1086 [inlined]
[12] copyto!
@ ./abstractarray.jl:1061 [inlined]
[13] copyto_axcheck!
@ ./abstractarray.jl:1167 [inlined]
[14] copyto_axcheck!(none::Array{Float32, 3}, none::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, false})
@ Reactant ./<missing>:0
[15] size
@ ./array.jl:194 [inlined]
[16] axes
@ ./abstractarray.jl:98 [inlined]
[17] copyto_axcheck!
@ ./abstractarray.jl:1166 [inlined]
[18] call_with_reactant(::typeof(Base.copyto_axcheck!), ::Array{Float32, 3}, ::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, false})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[19] Array
@ ./array.jl:626 [inlined]
[20] Array{Float32, 3}(none::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, false})
@ Reactant ./<missing>:0
[21] getproperty
@ ./Base.jl:49 [inlined]
[22] axes
@ ./subarray.jl:504 [inlined]
[23] size
@ ./subarray.jl:65 [inlined]
[24] Array
@ ./array.jl:626 [inlined]
[25] call_with_reactant(::Type{Array{Float32, 3}}, ::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, false})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[26] Array
@ ./boot.jl:606 [inlined]
[27] convert
@ ./array.jl:618 [inlined]
[28] convert(none::Type{Array{Float32}}, none::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, false})
@ Reactant ./<missing>:0
[29] call_with_reactant(::typeof(convert), ::Type{Array{Float32}}, ::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, false})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:423
[30] convert_output
@ ~/.julia/packages/Oceananigans/7k5MN/src/OutputWriters/fetch_output.jl:35 [inlined]
[31] convert_output(none::SubArray{…}, none::JLD2OutputWriter{…})
@ Reactant ./<missing>:0
[32] call_with_reactant(::typeof(Oceananigans.OutputWriters.convert_output), ::SubArray{…}, ::JLD2OutputWriter{…})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:423
[33] fetch_and_convert_output
@ ~/.julia/packages/Oceananigans/7k5MN/src/OutputWriters/fetch_output.jl:47 [inlined]
[34] #38
@ ./none:0 [inlined]
[35] (::Oceananigans.OutputWriters.var"#38#39"{JLD2OutputWriter{…}, OceanSeaIceModel{…}})(none::Tuple{Symbol, Field{…}})
@ Reactant ./<missing>:0
[36] indexed_iterate (repeats 2 times)
@ ./tuple.jl:159 [inlined]
[37] #38
@ ./none:0 [inlined]
[38] call_with_reactant(::Oceananigans.OutputWriters.var"#38#39"{JLD2OutputWriter{…}, OceanSeaIceModel{…}}, ::Tuple{Symbol, Field{…}})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[39] iterate
@ ./generator.jl:48 [inlined]
[40] merge
@ ./namedtuple.jl:371 [inlined]
[41] merge(none::@NamedTuple{}, none::Base.Generator{Base.Iterators.Zip{Tuple{…}}, Oceananigans.OutputWriters.var"#38#39"{JLD2OutputWriter{…}, OceanSeaIceModel{…}}})
@ Reactant ./<missing>:0
[42] GenericMemory
@ ./boot.jl:514 [inlined]
[43] Array
@ ./boot.jl:578 [inlined]
[44] getindex
@ ./array.jl:400 [inlined]
[45] merge
@ ./namedtuple.jl:368 [inlined]
[46] call_with_reactant(::typeof(merge), ::@NamedTuple{}, ::Base.Generator{Base.Iterators.Zip{Tuple{…}}, Oceananigans.OutputWriters.var"#38#39"{JLD2OutputWriter{…}, OceanSeaIceModel{…}}})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[47] NamedTuple
@ ./namedtuple.jl:151 [inlined]
[48] macro expansion
@ ./timing.jl:421 [inlined]
[49] write_output!
@ ~/.julia/packages/Oceananigans/7k5MN/src/OutputWriters/jld2_output_writer.jl:255 [inlined]
[50] write_output!(none::JLD2OutputWriter{…}, none::OceanSeaIceModel{…})
@ Reactant ./<missing>:0
[51] getproperty
@ ./Base.jl:49 [inlined]
[52] write_output!
@ ~/.julia/packages/Oceananigans/7k5MN/src/OutputWriters/jld2_output_writer.jl:236 [inlined]
[53] call_with_reactant(::typeof(Oceananigans.write_output!), ::JLD2OutputWriter{…}, ::OceanSeaIceModel{…})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[54] initialize!
@ ~/.julia/packages/Oceananigans/7k5MN/src/Simulations/run.jl:231 [inlined]
[55] initialize!(none::Simulation{OceanSeaIceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
@ Reactant ./<missing>:0
[56] getproperty
@ ./Base.jl:49 [inlined]
[57] initialize!
@ ~/.julia/packages/Oceananigans/7k5MN/src/Simulations/run.jl:194 [inlined]
[58] call_with_reactant(::typeof(Oceananigans.initialize!), ::Simulation{OceanSeaIceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[59] time_step!
@ ~/.julia/packages/Oceananigans/7k5MN/src/Simulations/run.jl:118 [inlined]
[60] time_step!(none::Simulation{OceanSeaIceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
@ Reactant ./<missing>:0
[61] time_ns
@ ./Base.jl:156 [inlined]
[62] time_step!
@ ~/.julia/packages/Oceananigans/7k5MN/src/Simulations/run.jl:113 [inlined]
[63] call_with_reactant(::typeof(time_step!), ::Simulation{OceanSeaIceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[64] #run!#7
@ ~/.julia/packages/Oceananigans/7k5MN/src/Simulations/run.jl:102 [inlined]
[65] var"#run!#7"(none::Bool, none::typeof(run!), none::Simulation{OceanSeaIceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
@ Reactant ./<missing>:0
[66] #run!#7
@ ~/.julia/packages/Oceananigans/7k5MN/src/Simulations/run.jl:93 [inlined]
[67] call_with_reactant(::Oceananigans.Simulations.var"##run!#7", ::Bool, ::typeof(run!), ::Simulation{OceanSeaIceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[68] run!
@ ~/.julia/packages/Oceananigans/7k5MN/src/Simulations/run.jl:91 [inlined]
[69] run!(none::Simulation{OceanSeaIceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
@ Reactant ./<missing>:0
[70] run!
@ ~/.julia/packages/Oceananigans/7k5MN/src/Simulations/run.jl:91 [inlined]
[71] call_with_reactant(::typeof(run!), ::Simulation{OceanSeaIceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[72] make_mlir_fn(f::Function, args::Tuple{Simulation{…}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool)
@ Reactant.TracedUtils ~/reactant/Reactant.jl/src/TracedUtils.jl:261
[73] make_mlir_fn
@ ~/reactant/Reactant.jl/src/TracedUtils.jl:153 [inlined]
[74] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{Simulation{…}}, callcache::Dict{Vector, @NamedTuple{…}}; optimize::Bool, no_nan::Bool, backend::String)
@ Reactant.Compiler ~/reactant/Reactant.jl/src/Compiler.jl:600
[75] compile_mlir!
@ ~/reactant/Reactant.jl/src/Compiler.jl:575 [inlined]
[76] (::Reactant.Compiler.var"#7#8"{@Kwargs{no_nan::Bool, optimize::Bool}, typeof(run!), Tuple{Simulation{OceanSeaIceModel{…}, Float64, Float64, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}, OrderedCollections.OrderedDict{…}}}})()
@ Reactant.Compiler ~/reactant/Reactant.jl/src/Compiler.jl:494
[77] context!(f::Reactant.Compiler.var"#7#8"{@Kwargs{no_nan::Bool, optimize::Bool}, typeof(run!), Tuple{Simulation{…}}}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR ~/reactant/Reactant.jl/src/mlir/IR/Context.jl:76
[78] compile_mlir(f::Function, args::Tuple{Simulation{…}}; client::Nothing, kwargs::@Kwargs{no_nan::Bool, optimize::Bool})
@ Reactant.Compiler ~/reactant/Reactant.jl/src/Compiler.jl:491
[79] top-level scope
@ ~/reactant/Reactant.jl/src/Compiler.jl:1010
[80] include(fname::String)
@ Main ./sysimg.jl:38
in expression starting at /home/avik-pal/reactant/GB-25/ocean-climate-simulation/ocean_climate_simulation_mlir.jl:7
Some type information was truncated. Use `show(err)` to see complete types.
julia> using Reactant
Precompiling Reactant...
1 dependency successfully precompiled in 8 seconds. 62 already precompiled.
AssertionError("Could not find registered platform with name: \"cuda\". Available platform names are: ")
julia> function fn(x)
    y = view(x, Base.OneTo(3), Base.OneTo(3), 1:4)
    return Array{Float32, 3}(y)
end
fn (generic function with 1 method)
julia> x = rand(4, 4, 4) |> Reactant.to_rarray
4×4×4 ConcreteRArray{Float64, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}:
[:, :, 1] =
0.680742 0.521762 0.00914581 0.314749
0.0177183 0.172033 0.134529 0.721138
0.33056 0.791873 0.332839 0.863476
0.0156262 0.141837 0.927453 0.804502
[:, :, 2] =
0.129421 0.480832 0.275151 0.402478
0.114712 0.00906093 0.622196 0.274481
0.423662 0.514893 0.880675 0.0368986
0.00839975 0.55047 0.270507 0.0301572
[:, :, 3] =
0.803851 0.734472 0.79049 0.556891
0.0205181 0.395981 0.691824 0.602781
0.595743 0.568166 0.636232 0.0494316
0.339392 0.469064 0.427616 0.962649
[:, :, 4] =
0.412271 0.713806 0.915033 0.987295
0.294725 0.377429 0.38638 0.998421
0.34498 0.750536 0.167738 0.447552
0.134442 0.860726 0.802057 0.661143
julia> @code_hlo fn(x)
ERROR: ArgumentError: cannot re-index SubArray with fewer indices than dimensions
This should not occur; please submit a bug report.
Stacktrace:
[1] macro expansion
@ ./subarray.jl:305 [inlined]
[2] reindex(idxs::Tuple{Base.OneTo{Int64}, UnitRange{Int64}}, subidxs::Tuple{})
@ Base ./subarray.jl:305
[3] reindex
@ ./subarray.jl:298 [inlined]
[4] get_ancestor_indices(x::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, UnitRange{Int64}}, false}, indices::CartesianIndex{3})
@ Reactant.TracedUtils ~/reactant/Reactant.jl/src/TracedUtils.jl:115
[5] getindex
@ ~/reactant/Reactant.jl/src/TracedRArray.jl:233 [inlined]
[6] getindex(none::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, UnitRange{Int64}}, false}, none::Tuple{CartesianIndex{3}})
@ Reactant ./<missing>:0
[7] getindex
@ ~/reactant/Reactant.jl/src/TracedRArray.jl:233 [inlined]
[8] call_with_reactant(::Reactant.MustThrowError, ::typeof(getindex), ::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, UnitRange{Int64}}, false}, ::CartesianIndex{3})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[9] iterate
@ ./abstractarray.jl:1209 [inlined]
[10] iterate
@ ./abstractarray.jl:1207 [inlined]
[11] copyto_unaliased!
@ ./abstractarray.jl:1086 [inlined]
[12] copyto!
@ ./abstractarray.jl:1061 [inlined]
[13] copyto_axcheck!
@ ./abstractarray.jl:1167 [inlined]
[14] copyto_axcheck!(none::Array{Float32, 3}, none::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, UnitRange{Int64}}, false})
@ Reactant ./<missing>:0
[15] size
@ ./array.jl:194 [inlined]
[16] axes
@ ./abstractarray.jl:98 [inlined]
[17] copyto_axcheck!
@ ./abstractarray.jl:1166 [inlined]
[18] call_with_reactant(::typeof(Base.copyto_axcheck!), ::Array{Float32, 3}, ::SubArray{Reactant.TracedRNumber{Float64}, 3, Reactant.TracedRArray{Float64, 3}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, UnitRange{Int64}}, false})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[19] Array
@ ./array.jl:626 [inlined]
[20] fn
@ ./REPL[2]:3 [inlined]
[21] fn(none::Reactant.TracedRArray{Float64, 3})
@ Reactant ./<missing>:0
[22] view
@ ./subarray.jl:214 [inlined]
[23] fn
@ ./REPL[2]:2 [inlined]
[24] call_with_reactant(::typeof(fn), ::Reactant.TracedRArray{Float64, 3})
@ Reactant ~/reactant/Reactant.jl/src/utils.jl:0
[25] make_mlir_fn(f::Function, args::Tuple{ConcreteRArray{…}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool)
@ Reactant.TracedUtils ~/reactant/Reactant.jl/src/TracedUtils.jl:261
[26] make_mlir_fn
@ ~/reactant/Reactant.jl/src/TracedUtils.jl:153 [inlined]
[27] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{…}}, callcache::Dict{Vector, @NamedTuple{…}}; optimize::Bool, no_nan::Bool, backend::String)
@ Reactant.Compiler ~/reactant/Reactant.jl/src/Compiler.jl:600
[28] compile_mlir!
@ ~/reactant/Reactant.jl/src/Compiler.jl:575 [inlined]
[29] (::Reactant.Compiler.var"#7#8"{@Kwargs{no_nan::Bool, optimize::Bool}, typeof(fn), Tuple{ConcreteRArray{Float64, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}})()
@ Reactant.Compiler ~/reactant/Reactant.jl/src/Compiler.jl:494
[30] context!(f::Reactant.Compiler.var"#7#8"{@Kwargs{no_nan::Bool, optimize::Bool}, typeof(fn), Tuple{ConcreteRArray{Float64, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR ~/reactant/Reactant.jl/src/mlir/IR/Context.jl:76
[31] compile_mlir(f::Function, args::Tuple{ConcreteRArray{Float64, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}; client::Nothing, kwargs::@Kwargs{no_nan::Bool, optimize::Bool})
@ Reactant.Compiler ~/reactant/Reactant.jl/src/Compiler.jl:491
[32] top-level scope
@ ~/reactant/Reactant.jl/src/Compiler.jl:1010
Some type information was truncated. Use `show(err)` to see complete types.
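For comparison, the same `view` plus `Array{Float32, 3}` conversion succeeds on an ordinary Julia array, so the failure appears specific to the traced path. One plausible reading of frame [2] is that `reindex` ran out of sub-indices (`subidxs::Tuple{}`) after the first parent dimension, which is what would happen if the whole `CartesianIndex{3}` were consumed as a single index instead of being split into its three integer components. The sketch below uses a hypothetical `ancestor_index` helper (not Reactant's `get_ancestor_indices`) to show the intended mapping, with the `CartesianIndex` splatted via `Tuple(I)`; it assumes a view that keeps every dimension.

```julia
# Plain-Julia baseline: the same view + conversion used in `fn` above,
# but on an ordinary CPU Array instead of a Reactant traced array.
x = rand(4, 4, 4)
y = view(x, Base.OneTo(3), Base.OneTo(3), 1:4)
z = Array{Float32, 3}(y)  # works in Base Julia: a 3×3×4 Float32 copy

# Hypothetical helper: map a CartesianIndex into the view back to an index
# into the parent array. The CartesianIndex is split into one integer per
# dimension, and each integer indexes the matching parent index range.
function ancestor_index(v::SubArray, I::CartesianIndex)
    pidx = parentindices(v)
    sub = Tuple(I)
    # This sketch only handles views that keep every dimension
    # (no scalar index dropping a dimension).
    length(pidx) == length(sub) ||
        throw(ArgumentError("sketch assumes one index per parent dimension"))
    return CartesianIndex(map((p, s) -> p[s], pidx, sub))
end

# Every element of the view agrees with the parent at the mapped index.
for I in CartesianIndices(y)
    @assert parent(y)[ancestor_index(y, I)] == y[I]
end
```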
cc @glwagner. If Oceananigans is explicitly constructing an `Array`, that is also a bug (it should pick the array type based on the backend).
@avik-pal the error stems from https://github.com/CliMA/Oceananigans.jl/blob/76c5f290130558f58c28208fde526a15334d3716/src/OutputWriters/fetch_output.jl#L35, right?
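On the point about not hard-coding the array type: a minimal sketch, assuming nothing about Oceananigans internals, is to allocate the destination with `similar`, which for a `SubArray` Base forwards to the parent array. The helper `convert_like_input` is hypothetical and is not Oceananigans' `convert_output`. Whether a traced parent yields a traced result depends on the backend implementing `similar`, and the element-wise `copyto!` still goes through ordinary indexing, so this addresses only the array-type question, not the re-indexing error in the traces.

```julia
# Hypothetical helper, not Oceananigans' `convert_output`: convert `output`
# to element type `T`, letting the input pick the container via `similar`.
function convert_like_input(output::AbstractArray, ::Type{T}) where {T}
    dest = similar(output, T)  # for a SubArray, Base forwards this to the parent
    copyto!(dest, output)      # element-wise copy, converting each value to T
    return dest
end

# Usage on an ordinary CPU array: the parent is an Array, so the result is too.
x = rand(4, 4, 4)
v = view(x, :, :, 1:3)
out = convert_like_input(v, Float32)
```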
@glwagner I suppose at a higher level we should talk about how the output writer should work here. Either we shouldn't write output during `run!` and should instead do it explicitly between calls to Reactant-compiled functions, or we add a custom call, as we do for MPI. If possible, the first option would be preferred for GB.
That's right.
Hmm, OK. For the GB run we probably want output fairly frequently, e.g. every 6 hours, in a run that hopefully covers 30 days (120 output calls in total). For general science, arbitrarily frequent output needs to be supported too. What kind of workflow do you envision?
Can we save |
A pattern like this might work for GB:
save_frequency = 100 # iterations
ow = JLD2OutputWriter(model, outputs, kw...)
# do NOT add the ow to `simulation.output_writers`. Instead...
r_run! = @compile run!(simulation)
# Assuming we haven't run anything yet:
simulation.stop_iteration = 0
for saves = 1:120 # total iterations = 100 * 120
    simulation.stop_iteration += save_frequency
    r_run!(simulation)
    # manually save output:
    Oceananigans.OutputWriters.write_output!(ow, model)
end
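The loop above can be exercised without Reactant or Oceananigans by using stand-ins. In this toy sketch `ToySim`, `advance_chunk!`, and `save_output!` are hypothetical: `advance_chunk!` plays the role of the compiled `run!`, and `save_output!` plays the role of `write_output!` on a writer kept out of `simulation.output_writers`. It only checks the bookkeeping: 120 chunks of 100 iterations each, with one save after every chunk.

```julia
# Hypothetical stand-in for a Simulation: only the two counters the loop uses.
mutable struct ToySim
    iteration::Int
    stop_iteration::Int
end

# Stand-in for the compiled run!: steps until stop_iteration is reached.
function advance_chunk!(sim::ToySim)
    while sim.iteration < sim.stop_iteration
        sim.iteration += 1
    end
    return sim
end

# Stand-in for write_output!: record the iteration at which output was saved.
save_output!(log::Vector{Int}, sim::ToySim) = push!(log, sim.iteration)

save_frequency = 100  # iterations between saves
n_saves = 120         # matches the count estimated in the thread (6-hourly output over 30 days)

sim = ToySim(0, 0)
saved_at = Int[]
for _ in 1:n_saves
    sim.stop_iteration += save_frequency
    advance_chunk!(sim)           # "compiled" chunk: no output inside
    save_output!(saved_at, sim)   # output written between chunks
end
```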
The problem is that we can't save from inside a compiled function at the moment.
OK, well, I think it is relatively easy to work around. I want to offer this in case it is helpful. I'm not sure we need the features of
function many_time_steps!(model, dt, Nsteps)
    update_state!(model) # initialize
    for n = 1:Nsteps
        time_step!(model, dt)
    end
    return nothing
end
then we could use
# the compiled function takes the same arguments as the traced call
r_many_time_steps! = @compile many_time_steps!(model, dt, Nsteps)
r_many_time_steps!(model, dt, Nsteps)
I think that's fine (though we'll want to use `@trace for` there regardless, and also conditionally initialize if time == 0).
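The conditional initialization suggested above can be sketched in plain Julia. `ToyModel`, `update_state!`, and `time_step!` here are hypothetical stand-ins, not the Oceananigans functions of the same name, and the loop is an ordinary Julia `for`; the tracing of the loop mentioned above for Reactant is omitted. The sketch checks that repeated calls initialize only once, on the call that starts at time zero.

```julia
# Hypothetical stand-in for a model: a clock plus a counter of initializations.
mutable struct ToyModel
    time::Float64
    initialized_count::Int
end

update_state!(model::ToyModel) = (model.initialized_count += 1; model)

function time_step!(model::ToyModel, dt)
    model.time += dt
    return model
end

function many_time_steps!(model::ToyModel, dt, Nsteps)
    # Only initialize on the first call, i.e. when the clock is at zero.
    if model.time == 0
        update_state!(model)
    end
    for _ in 1:Nsteps
        time_step!(model, dt)
    end
    return nothing
end

model = ToyModel(0.0, 0)
many_time_steps!(model, 0.5, 10)  # initializes, then takes 10 steps
many_time_steps!(model, 0.5, 10)  # time > 0, so no re-initialization
```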
Hello, I will reply to you once I see this.