Skip to content

Commit

Permalink
feat!: add Tag parameter to SpectralData
Browse files Browse the repository at this point in the history
This enormously simplifies derived datatypes, and removed the API
forwarding macro that was unreliable to maintain. Follows from the
Tagging idea introduced with BinnedData, and allows derived types to be
implemented more easily.

Updated OGIP, XMM and NuSTAR datatypes to use this tagging system
instead of adding new types.

Modified test cases to sync with changes.
  • Loading branch information
fjebaker committed Jan 23, 2025
1 parent 9f4d50c commit 269d327
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 186 deletions.
4 changes: 1 addition & 3 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ import Optimization

using DocStringExtensions

# for future use: mission specific parsing
abstract type AbstractMission end
struct NoMission <: AbstractMission end
abstract type AbstractInstrument end

abstract type AbstractStatistic end
struct ChiSquared <: AbstractStatistic end
Expand Down
87 changes: 12 additions & 75 deletions src/datasets/mission-specifics.jl
Original file line number Diff line number Diff line change
@@ -1,87 +1,24 @@
struct NuSTAR <: AbstractInstrument end
struct XmmEPIC <: AbstractInstrument end

struct NuStarData{T,H} <: AbstractDataset
data::SpectralData{T}
paths::SpectralDataPaths
observation_id::String
exposure_id::String
object::String
header::H
end

NuStarData(spec_path; rmf_matrix_index = 3, rmf_energy_index = 2, kwargs...) = NuStarData(
load_ogip_dataset(
function NuStarData(spec_path; rmf_matrix_index = 3, rmf_energy_index = 2, kwargs...)
OGIPDataset(
spec_path;
rmf_matrix_index = rmf_matrix_index,
tag = NuSTAR(),
rmf_energy_index = rmf_energy_index,
rmf_matrix_index = rmf_matrix_index,
kwargs...,
)...,
)

make_label(data::NuStarData) = data.observation_id

@_forward_SpectralData_api NuStarData.data

function Base.show(io::IO, @nospecialize(data::NuStarData{T})) where {T}
print(io, "NuStarData[obs_id=$(data.observation_id)]")
end

function _printinfo(io, data::NuStarData{T}) where {T}
descr = """NuStarData:
. Object : $(data.object)
. Observation ID : $(data.observation_id)
. Exposure ID : $(data.exposure_id)
"""
print(io, descr)
_printinfo(io, data.data)
)
end


abstract type AbstractXmmNewtonDevice end
struct XmmEPIC <: AbstractXmmNewtonDevice end

struct XmmData{T,H,D} <: AbstractDataset
device::D
data::SpectralData{T}
paths::SpectralDataPaths
observation_id::String
exposure_id::String
object::String
header::H
end

XmmData(
device::AbstractXmmNewtonDevice,
spec_path;
rmf_matrix_index = 2,
rmf_energy_index = 3,
kwargs...,
) = XmmData(
device,
load_ogip_dataset(
function XmmData(spec_path; rmf_matrix_index = 2, rmf_energy_index = 3, kwargs...)
OGIPDataset(
spec_path;
rmf_matrix_index = rmf_matrix_index,
tag = XmmEPIC(),
rmf_energy_index = rmf_energy_index,
rmf_matrix_index = rmf_matrix_index,
kwargs...,
)...,
)

make_label(data::XmmData) = data.observation_id

@_forward_SpectralData_api XmmData.data

function Base.show(io::IO, @nospecialize(data::XmmData{T})) where {T}
print(io, "XmmData[dev=$(data.device),obs_id=$(data.observation_id)]")
)
end

function _printinfo(io, data::XmmData{T}) where {T}
descr = """XmmData for $(Base.typename(typeof(data.device)).name):
. Object : $(data.object)
. Observation ID : $(data.observation_id)
. Exposure ID : $(data.exposure_id)
"""
print(io, descr)
_printinfo(io, data.data)
end


export NuStarData, XmmData, XmmEPIC
39 changes: 23 additions & 16 deletions src/datasets/ogipdataset.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
struct OGIPDataset{T,H} <: AbstractDataset
data::SpectralData{T}
struct OGIPData end

struct OGIPMetadata{H}
paths::SpectralDataPaths
observation_id::String
exposure_id::String
object::String
header::H
end

function load_ogip_dataset(
function ogip_dataset(
spec_path;
hdu = 2,
background = nothing,
Expand All @@ -27,24 +28,30 @@ function load_ogip_dataset(
exposure_id = haskey(header, "EXP_ID") ? header["EXP_ID"] : "[no exposure id]"
object = haskey(header, "OBJECT") ? header["OBJECT"] : "[no object]"

data = SpectralData(paths)
(data, paths, obs_id, exposure_id, object, header)
(; paths, obs_id, exposure_id, object, header)
end

OGIPDataset(spec_path; kwargs...) = OGIPDataset(load_ogip_dataset(spec_path; kwargs...)...)

@_forward_SpectralData_api OGIPDataset.data

make_label(d::OGIPDataset) = d.observation_id
function OGIPDataset(spec_path; tag = OGIPData(), kwargs...)
info = ogip_dataset(spec_path; kwargs...)
metadata =
OGIPMetadata(info.paths, info.obs_id, info.exposure_id, info.object, info.header)
SpectralData(info.paths, tag = tag, user_data = metadata)
end

function _printinfo(io, data::OGIPDataset{T}) where {T}
descr = """OGIPDataset:
. Object : $(data.object)
. Observation ID : $(data.observation_id)
. Exposure ID : $(data.exposure_id)
make_label(data::SpectralData{T,<:Union{<:AbstractInstrument,<:OGIPData}}) where {T} =
data.user_data.observation_id

function _printinfo(
io,
data::SpectralData{T,K},
) where {T,K<:Union{<:AbstractInstrument,<:OGIPData}}
descr = """SpectralDataset{$K}:
. Object : $(data.user_data.object)
. Observation ID : $(data.user_data.observation_id)
. Exposure ID : $(data.user_data.exposure_id)
"""
print(io, descr)
_printinfo(io, data.data)
print_spectral_data_info(io, data)
end

export OGIPDataset
105 changes: 28 additions & 77 deletions src/datasets/spectraldata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ If the spectrum and repsonse matrix have already been loaded seperately, use
ancillary = nothing,
)
"""
mutable struct SpectralData{T} <: AbstractDataset
mutable struct SpectralData{T,Tag,D} <: AbstractDataset
"Observed spectrum to be fitted."
spectrum::Spectrum{T}
"Instrument response."
Expand All @@ -102,27 +102,42 @@ mutable struct SpectralData{T} <: AbstractDataset
domain::Vector{T}
"Mask representing which bins are to be included in the fit."
data_mask::BitVector
user_data::D
end

# constructor

function SpectralData(paths::SpectralDataPaths; kwargs...)
function SpectralData(
paths::SpectralDataPaths;
tag = nothing,
user_data = nothing,
kwargs...,
)
spec, resp, back, anc = _read_all_ogip(paths; kwargs...)
SpectralData(spec, resp; background = back, ancillary = anc)
SpectralData(
spec,
resp;
background = back,
ancillary = anc,
tag = tag,
user_data = user_data,
)
end

function SpectralData(
spectrum::Spectrum,
spectrum::Spectrum{T},
response::ResponseMatrix;
# try to match the domains of the response matrix to the data
match_domains = true,
background = nothing,
ancillary = nothing,
)
tag = nothing,
user_data = nothing,
) where {T}
domain = _make_domain_vector(spectrum, response)
energy_low, energy_high = _make_energy_vector(spectrum, response)
data_mask = BitVector(fill(true, size(spectrum.data)))
data = SpectralData(
data = SpectralData{T,typeof(tag),typeof(user_data)}(
spectrum,
response,
background,
Expand All @@ -131,6 +146,7 @@ function SpectralData(
energy_high,
domain,
data_mask,
user_data,
)
if !check_domains(data)
if match_domains
Expand All @@ -143,7 +159,7 @@ function SpectralData(
end

function Base.copy(data::SpectralData)
SpectralData(
typeof(data)(
data.spectrum,
data.response,
data.background,
Expand All @@ -152,6 +168,7 @@ function Base.copy(data::SpectralData)
copy(data.energy_high),
copy(data.domain),
copy(data.data_mask),
copy(data.user_data),
)
end

Expand Down Expand Up @@ -523,83 +540,17 @@ function invokemodel(
invokemodel!(output, domain, model, cache)[data.data_mask]
end

macro _forward_SpectralData_api(args)
if args.head !== :.
error("Bad syntax")
end
T, field = args.args
quote
SpectralFitting.supports(t::Type{<:$(T)}) = (ContiguouslyBinned(),)
SpectralFitting.preferred_units(t::Type{<:$(T)}, u::AbstractStatistic) =
SpectralFitting.preferred_units(SpectralData, u)
SpectralFitting.make_output_domain(
layout::SpectralFitting.AbstractLayout,
t::$(T),
) = SpectralFitting.make_output_domain(layout, getfield(t, $(field)))
SpectralFitting.make_model_domain(layout::SpectralFitting.AbstractLayout, t::$(T)) =
SpectralFitting.make_model_domain(layout, getfield(t, $(field)))
SpectralFitting.make_domain_variance(
layout::SpectralFitting.AbstractLayout,
t::$(T),
) = SpectralFitting.make_domain_variance(layout, getfield(t, $(field)))
SpectralFitting.make_objective(layout::SpectralFitting.AbstractLayout, t::$(T)) =
SpectralFitting.make_objective(layout, getfield(t, $(field)))
SpectralFitting.make_objective_variance(
layout::SpectralFitting.AbstractLayout,
t::$(T),
) = SpectralFitting.make_objective_variance(layout, getfield(t, $(field)))
SpectralFitting.objective_transformer(
layout::SpectralFitting.AbstractLayout,
t::$(T),
) = SpectralFitting.objective_transformer(layout, getfield(t, $(field)))
SpectralFitting.regroup!(t::$(T), args...; kwargs...) =
SpectralFitting.regroup!(getfield(t, $(field)), args...; kwargs...)
SpectralFitting.restrict_domain!(t::$(T), args...) =
SpectralFitting.restrict_domain!(getfield(t, $(field)), args...)
SpectralFitting.mask_energies!(t::$(T), args...) =
SpectralFitting.mask_energies!(getfield(t, $(field)), args...)
SpectralFitting.drop_channels!(t::$(T), args...) =
SpectralFitting.drop_channels!(getfield(t, $(field)), args...)
SpectralFitting.drop_bad_channels!(t::$(T)) =
SpectralFitting.drop_bad_channels!(getfield(t, $(field)))
SpectralFitting.drop_negative_channels!(t::$(T)) =
SpectralFitting.drop_negative_channels!(getfield(t, $(field)))
SpectralFitting.normalize!(t::$(T)) =
SpectralFitting.normalize!(getfield(t, $(field)))
SpectralFitting.objective_units(t::$(T)) =
SpectralFitting.objective_units(getfield(t, $(field)))
SpectralFitting.spectrum_energy(t::$(T)) =
SpectralFitting.spectrum_energy(getfield(t, $(field)))
SpectralFitting.bin_widths(t::$(T)) =
SpectralFitting.bin_widths(getfield(t, $(field)))
SpectralFitting.subtract_background!(t::$(T), args...) =
SpectralFitting.subtract_background!(getfield(t, $(field)), args...)
SpectralFitting.set_domain!(t::$(T), args...) =
SpectralFitting.set_domain!(getfield(t, $(field)), args...)
SpectralFitting.error_statistic(t::$(T)) =
SpectralFitting.error_statistic(getfield(t, $(field)))
SpectralFitting.set_units!(t::$(T), args...) =
SpectralFitting.set_units!(getfield(t, $(field)), args...)
SpectralFitting.background_dataset(t::$(T), args...; kwargs...) =
SpectralFitting.background_dataset(getfield(t, $(field)), args...; kwargs...)
SpectralFitting.rescale_background!(t::$(T), args...; kwargs...) =
SpectralFitting.rescale_background!(
getfield(t, $(field)),
args...;
kwargs...,
)
SpectralFitting.rescale!(t::$(T), args...; kwargs...) =
SpectralFitting.rescale!(getfield(t, $(field)), args...; kwargs...)
end |> esc
end

# printing utilities

function Base.show(io::IO, @nospecialize(data::SpectralData))
print(io, "SpectralData[$(data.spectrum.telescope_name)]")
end

function _printinfo(io, data::SpectralData{T}) where {T}
print_spectral_data_info(io, data)
end

function print_spectral_data_info(io, data::SpectralData{T}) where {T}
domain = @views data.domain
ce_min = @views prettyfloat.(minimum(data.energy_low[data.data_mask]))
ce_max = @views prettyfloat.(maximum(data.energy_high[data.data_mask]))
Expand Down
5 changes: 1 addition & 4 deletions test/fitting/test-sample-data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ function prepare_data!(data, low, high)
end

# path to the data directory
data1 = SpectralFitting.XmmData(
SpectralFitting.XmmEPIC(),
joinpath(testdir, "xmm/pn_spec_grp.fits"),
)
data1 = SpectralFitting.XmmData(joinpath(testdir, "xmm/pn_spec_grp.fits"))
prepare_data!(data1, 0.8, 10.0)

model = PhotoelectricAbsorption() * XS_PowerLaw() + XS_Laor()
Expand Down
13 changes: 2 additions & 11 deletions test/simulation/test-sample-data-sim.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
using Test, SpectralFitting

# path to the data directory
data1 = SpectralFitting.XmmData(
SpectralFitting.XmmEPIC(),
joinpath(testdir, "xmm/pn_spec_grp.fits"),
)
data1 = SpectralFitting.XmmData(joinpath(testdir, "xmm/pn_spec_grp.fits"))


# smoke test
model = GaussianLine() + PowerLaw(a = FitParam(0.2))
sim = simulate(
model,
data1.data.response,
data1.data.ancillary;
seed = 8,
exposure_time = 1e1,
)
sim = simulate(model, data1.response, data1.ancillary; seed = 8, exposure_time = 1e1)

# TODO: add a fit. can't do it at the moment as the simulated datasets don't
# support masking the model domain, so we have singular values at high energies

0 comments on commit 269d327

Please sign in to comment.