From 269d3274482c25e1c9907e6ff184e2f3cbe21366 Mon Sep 17 00:00:00 2001 From: fjebaker Date: Thu, 23 Jan 2025 16:55:15 +0000 Subject: [PATCH] feat!: add Tag parameter to SpectralData This enormously simplifies derived datatypes, and removed the API forwarding macro that was unreliable to maintain. Follows from the Tagging idea introduced with BinnedData, and allows derived types to be implemented more easily. Updated OGIP, XMM and NuSTAR datatypes to use this tagging system instead of adding new types. Modified test cases to sync with changes. --- src/SpectralFitting.jl | 4 +- src/datasets/mission-specifics.jl | 87 +++----------------- src/datasets/ogipdataset.jl | 39 +++++---- src/datasets/spectraldata.jl | 105 +++++++----------------- test/fitting/test-sample-data.jl | 5 +- test/simulation/test-sample-data-sim.jl | 13 +-- 6 files changed, 67 insertions(+), 186 deletions(-) diff --git a/src/SpectralFitting.jl b/src/SpectralFitting.jl index 52f14dae..846c8c37 100644 --- a/src/SpectralFitting.jl +++ b/src/SpectralFitting.jl @@ -29,9 +29,7 @@ import Optimization using DocStringExtensions -# for future use: mission specific parsing -abstract type AbstractMission end -struct NoMission <: AbstractMission end +abstract type AbstractInstrument end abstract type AbstractStatistic end struct ChiSquared <: AbstractStatistic end diff --git a/src/datasets/mission-specifics.jl b/src/datasets/mission-specifics.jl index 07aad4cb..8b3ff34d 100644 --- a/src/datasets/mission-specifics.jl +++ b/src/datasets/mission-specifics.jl @@ -1,87 +1,24 @@ +struct NuSTAR <: AbstractInstrument end +struct XmmEPIC <: AbstractInstrument end -struct NuStarData{T,H} <: AbstractDataset - data::SpectralData{T} - paths::SpectralDataPaths - observation_id::String - exposure_id::String - object::String - header::H -end - -NuStarData(spec_path; rmf_matrix_index = 3, rmf_energy_index = 2, kwargs...) = NuStarData( - load_ogip_dataset( +function NuStarData(spec_path; rmf_matrix_index = 3, rmf_energy_index = 2, kwargs...) + OGIPDataset( spec_path; - rmf_matrix_index = rmf_matrix_index, + tag = NuSTAR(), rmf_energy_index = rmf_energy_index, + rmf_matrix_index = rmf_matrix_index, kwargs..., - )..., -) - -make_label(data::NuStarData) = data.observation_id - -@_forward_SpectralData_api NuStarData.data - -function Base.show(io::IO, @nospecialize(data::NuStarData{T})) where {T} - print(io, "NuStarData[obs_id=$(data.observation_id)]") -end - -function _printinfo(io, data::NuStarData{T}) where {T} - descr = """NuStarData: - . Object : $(data.object) - . Observation ID : $(data.observation_id) - . Exposure ID : $(data.exposure_id) - """ - print(io, descr) - _printinfo(io, data.data) + ) end - -abstract type AbstractXmmNewtonDevice end -struct XmmEPIC <: AbstractXmmNewtonDevice end - -struct XmmData{T,H,D} <: AbstractDataset - device::D - data::SpectralData{T} - paths::SpectralDataPaths - observation_id::String - exposure_id::String - object::String - header::H -end - -XmmData( - device::AbstractXmmNewtonDevice, - spec_path; - rmf_matrix_index = 2, - rmf_energy_index = 3, - kwargs..., -) = XmmData( - device, - load_ogip_dataset( +function XmmData(spec_path; rmf_matrix_index = 2, rmf_energy_index = 3, kwargs...) + OGIPDataset( spec_path; - rmf_matrix_index = rmf_matrix_index, + tag = XmmEPIC(), rmf_energy_index = rmf_energy_index, + rmf_matrix_index = rmf_matrix_index, kwargs..., - )..., -) - -make_label(data::XmmData) = data.observation_id - -@_forward_SpectralData_api XmmData.data - -function Base.show(io::IO, @nospecialize(data::XmmData{T})) where {T} - print(io, "XmmData[dev=$(data.device),obs_id=$(data.observation_id)]") + ) end -function _printinfo(io, data::XmmData{T}) where {T} - descr = """XmmData for $(Base.typename(typeof(data.device)).name): - . Object : $(data.object) - . Observation ID : $(data.observation_id) - . Exposure ID : $(data.exposure_id) - """ - print(io, descr) - _printinfo(io, data.data) -end - - export NuStarData, XmmData, XmmEPIC diff --git a/src/datasets/ogipdataset.jl b/src/datasets/ogipdataset.jl index 4be6480e..4c13fa3f 100644 --- a/src/datasets/ogipdataset.jl +++ b/src/datasets/ogipdataset.jl @@ -1,5 +1,6 @@ -struct OGIPDataset{T,H} <: AbstractDataset - data::SpectralData{T} +struct OGIPData end + +struct OGIPMetadata{H} paths::SpectralDataPaths observation_id::String exposure_id::String @@ -7,7 +8,7 @@ struct OGIPDataset{T,H} <: AbstractDataset header::H end -function load_ogip_dataset( +function ogip_dataset( spec_path; hdu = 2, background = nothing, @@ -27,24 +28,30 @@ function load_ogip_dataset( exposure_id = haskey(header, "EXP_ID") ? header["EXP_ID"] : "[no exposure id]" object = haskey(header, "OBJECT") ? header["OBJECT"] : "[no object]" - data = SpectralData(paths) - (data, paths, obs_id, exposure_id, object, header) + (; paths, obs_id, exposure_id, object, header) end -OGIPDataset(spec_path; kwargs...) = OGIPDataset(load_ogip_dataset(spec_path; kwargs...)...) - -@_forward_SpectralData_api OGIPDataset.data - -make_label(d::OGIPDataset) = d.observation_id +function OGIPDataset(spec_path; tag = OGIPData(), kwargs...) + info = ogip_dataset(spec_path; kwargs...) + metadata = + OGIPMetadata(info.paths, info.obs_id, info.exposure_id, info.object, info.header) + SpectralData(info.paths, tag = tag, user_data = metadata) +end -function _printinfo(io, data::OGIPDataset{T}) where {T} - descr = """OGIPDataset: - . Object : $(data.object) - . Observation ID : $(data.observation_id) - . Exposure ID : $(data.exposure_id) +make_label(data::SpectralData{T,<:Union{<:AbstractInstrument,<:OGIPData}}) where {T} = + data.user_data.observation_id + +function _printinfo( + io, + data::SpectralData{T,K}, +) where {T,K<:Union{<:AbstractInstrument,<:OGIPData}} + descr = """SpectralDataset{$K}: + . Object : $(data.user_data.object) + . Observation ID : $(data.user_data.observation_id) + . Exposure ID : $(data.user_data.exposure_id) """ print(io, descr) - _printinfo(io, data.data) + print_spectral_data_info(io, data) end export OGIPDataset diff --git a/src/datasets/spectraldata.jl b/src/datasets/spectraldata.jl index 9b5c6c28..548915a9 100644 --- a/src/datasets/spectraldata.jl +++ b/src/datasets/spectraldata.jl @@ -85,7 +85,7 @@ If the spectrum and repsonse matrix have already been loaded seperately, use ancillary = nothing, ) """ -mutable struct SpectralData{T} <: AbstractDataset +mutable struct SpectralData{T,Tag,D} <: AbstractDataset "Observed spectrum to be fitted." spectrum::Spectrum{T} "Instrument response." @@ -102,27 +102,42 @@ mutable struct SpectralData{T} <: AbstractDataset domain::Vector{T} "Mask representing which bins are to be included in the fit." data_mask::BitVector + user_data::D end # constructor -function SpectralData(paths::SpectralDataPaths; kwargs...) +function SpectralData( + paths::SpectralDataPaths; + tag = nothing, + user_data = nothing, + kwargs..., +) spec, resp, back, anc = _read_all_ogip(paths; kwargs...) - SpectralData(spec, resp; background = back, ancillary = anc) + SpectralData( + spec, + resp; + background = back, + ancillary = anc, + tag = tag, + user_data = user_data, + ) end function SpectralData( - spectrum::Spectrum, + spectrum::Spectrum{T}, response::ResponseMatrix; # try to match the domains of the response matrix to the data match_domains = true, background = nothing, ancillary = nothing, -) + tag = nothing, + user_data = nothing, +) where {T} domain = _make_domain_vector(spectrum, response) energy_low, energy_high = _make_energy_vector(spectrum, response) data_mask = BitVector(fill(true, size(spectrum.data))) - data = SpectralData( + data = SpectralData{T,typeof(tag),typeof(user_data)}( spectrum, response, background, @@ -131,6 +146,7 @@ function SpectralData( energy_high, domain, data_mask, + user_data, ) if !check_domains(data) if match_domains @@ -143,7 +159,7 @@ function SpectralData( end function Base.copy(data::SpectralData) - SpectralData( + typeof(data)( data.spectrum, data.response, data.background, @@ -152,6 +168,7 @@ function Base.copy(data::SpectralData) copy(data.energy_high), copy(data.domain), copy(data.data_mask), + copy(data.user_data), ) end @@ -523,76 +540,6 @@ function invokemodel( invokemodel!(output, domain, model, cache)[data.data_mask] end -macro _forward_SpectralData_api(args) - if args.head !== :. - error("Bad syntax") - end - T, field = args.args - quote - SpectralFitting.supports(t::Type{<:$(T)}) = (ContiguouslyBinned(),) - SpectralFitting.preferred_units(t::Type{<:$(T)}, u::AbstractStatistic) = - SpectralFitting.preferred_units(SpectralData, u) - SpectralFitting.make_output_domain( - layout::SpectralFitting.AbstractLayout, - t::$(T), - ) = SpectralFitting.make_output_domain(layout, getfield(t, $(field))) - SpectralFitting.make_model_domain(layout::SpectralFitting.AbstractLayout, t::$(T)) = - SpectralFitting.make_model_domain(layout, getfield(t, $(field))) - SpectralFitting.make_domain_variance( - layout::SpectralFitting.AbstractLayout, - t::$(T), - ) = SpectralFitting.make_domain_variance(layout, getfield(t, $(field))) - SpectralFitting.make_objective(layout::SpectralFitting.AbstractLayout, t::$(T)) = - SpectralFitting.make_objective(layout, getfield(t, $(field))) - SpectralFitting.make_objective_variance( - layout::SpectralFitting.AbstractLayout, - t::$(T), - ) = SpectralFitting.make_objective_variance(layout, getfield(t, $(field))) - SpectralFitting.objective_transformer( - layout::SpectralFitting.AbstractLayout, - t::$(T), - ) = SpectralFitting.objective_transformer(layout, getfield(t, $(field))) - SpectralFitting.regroup!(t::$(T), args...; kwargs...) = - SpectralFitting.regroup!(getfield(t, $(field)), args...; kwargs...) - SpectralFitting.restrict_domain!(t::$(T), args...) = - SpectralFitting.restrict_domain!(getfield(t, $(field)), args...) - SpectralFitting.mask_energies!(t::$(T), args...) = - SpectralFitting.mask_energies!(getfield(t, $(field)), args...) - SpectralFitting.drop_channels!(t::$(T), args...) = - SpectralFitting.drop_channels!(getfield(t, $(field)), args...) - SpectralFitting.drop_bad_channels!(t::$(T)) = - SpectralFitting.drop_bad_channels!(getfield(t, $(field))) - SpectralFitting.drop_negative_channels!(t::$(T)) = - SpectralFitting.drop_negative_channels!(getfield(t, $(field))) - SpectralFitting.normalize!(t::$(T)) = - SpectralFitting.normalize!(getfield(t, $(field))) - SpectralFitting.objective_units(t::$(T)) = - SpectralFitting.objective_units(getfield(t, $(field))) - SpectralFitting.spectrum_energy(t::$(T)) = - SpectralFitting.spectrum_energy(getfield(t, $(field))) - SpectralFitting.bin_widths(t::$(T)) = - SpectralFitting.bin_widths(getfield(t, $(field))) - SpectralFitting.subtract_background!(t::$(T), args...) = - SpectralFitting.subtract_background!(getfield(t, $(field)), args...) - SpectralFitting.set_domain!(t::$(T), args...) = - SpectralFitting.set_domain!(getfield(t, $(field)), args...) - SpectralFitting.error_statistic(t::$(T)) = - SpectralFitting.error_statistic(getfield(t, $(field))) - SpectralFitting.set_units!(t::$(T), args...) = - SpectralFitting.set_units!(getfield(t, $(field)), args...) - SpectralFitting.background_dataset(t::$(T), args...; kwargs...) = - SpectralFitting.background_dataset(getfield(t, $(field)), args...; kwargs...) - SpectralFitting.rescale_background!(t::$(T), args...; kwargs...) = - SpectralFitting.rescale_background!( - getfield(t, $(field)), - args...; - kwargs..., - ) - SpectralFitting.rescale!(t::$(T), args...; kwargs...) = - SpectralFitting.rescale!(getfield(t, $(field)), args...; kwargs...) - end |> esc -end - # printing utilities function Base.show(io::IO, @nospecialize(data::SpectralData)) @@ -600,6 +547,10 @@ function Base.show(io::IO, @nospecialize(data::SpectralData)) end function _printinfo(io, data::SpectralData{T}) where {T} + print_spectral_data_info(io, data) +end + +function print_spectral_data_info(io, data::SpectralData{T}) where {T} domain = @views data.domain ce_min = @views prettyfloat.(minimum(data.energy_low[data.data_mask])) ce_max = @views prettyfloat.(maximum(data.energy_high[data.data_mask])) diff --git a/test/fitting/test-sample-data.jl b/test/fitting/test-sample-data.jl index 3831cd19..2338db50 100644 --- a/test/fitting/test-sample-data.jl +++ b/test/fitting/test-sample-data.jl @@ -14,10 +14,7 @@ function prepare_data!(data, low, high) end # path to the data directory -data1 = SpectralFitting.XmmData( - SpectralFitting.XmmEPIC(), - joinpath(testdir, "xmm/pn_spec_grp.fits"), -) +data1 = SpectralFitting.XmmData(joinpath(testdir, "xmm/pn_spec_grp.fits")) prepare_data!(data1, 0.8, 10.0) model = PhotoelectricAbsorption() * XS_PowerLaw() + XS_Laor() diff --git a/test/simulation/test-sample-data-sim.jl b/test/simulation/test-sample-data-sim.jl index a1ee214a..3451183b 100644 --- a/test/simulation/test-sample-data-sim.jl +++ b/test/simulation/test-sample-data-sim.jl @@ -1,21 +1,12 @@ using Test, SpectralFitting # path to the data directory -data1 = SpectralFitting.XmmData( - SpectralFitting.XmmEPIC(), - joinpath(testdir, "xmm/pn_spec_grp.fits"), -) +data1 = SpectralFitting.XmmData(joinpath(testdir, "xmm/pn_spec_grp.fits")) # smoke test model = GaussianLine() + PowerLaw(a = FitParam(0.2)) -sim = simulate( - model, - data1.data.response, - data1.data.ancillary; - seed = 8, - exposure_time = 1e1, -) +sim = simulate(model, data1.response, data1.ancillary; seed = 8, exposure_time = 1e1) # TODO: add a fit. can't do it at the moment as the simulated datasets don't # support masking the model domain, so we have singular values at high energies