From 8f7c5e5f2bf40f6d97806b164a97f4f8e34c09a0 Mon Sep 17 00:00:00 2001 From: PharmCat Date: Mon, 1 Jan 2024 20:43:10 +0300 Subject: [PATCH] wts first concept --- docs/src/details.md | 14 +++++++++++++- src/linearalgebra.jl | 14 ++++++++++++++ src/lmm.jl | 22 ++++++++++++++++------ src/lmmdata.jl | 2 +- src/miboot.jl | 32 +++++++++++++++++++++++++++----- src/utils.jl | 22 ++++++++++++++++++++-- test/csv/df0.csv | 42 +++++++++++++++++++++--------------------- test/test.jl | 17 +++++++++++++++++ test/testdata.jl | 2 +- 9 files changed, 130 insertions(+), 37 deletions(-) diff --git a/docs/src/details.md b/docs/src/details.md index 4ecea801..b1d0d7f9 100644 --- a/docs/src/details.md +++ b/docs/src/details.md @@ -33,7 +33,7 @@ logREML(\theta,\beta) = -\frac{N-p}{2} - \frac{1}{2}\sum_{i=1}^nlog|V_{\theta, i -\frac{1}{2}log|\sum_{i=1}^nX_i'V_{\theta, i}^{-1}X_i|-\frac{1}{2}\sum_{i=1}^n(y_i - X_{i}\beta)'V_{\theta, i}^{-1}(y_i - X_{i}\beta) ``` -Actually ```L(\theta) = -2logREML = L_1(\theta) + L_2(\theta) + L_3(\theta) + c`` used for optimization, where: +Actually ```L(\theta) = -2logREML = L_1(\theta) + L_2(\theta) + L_3(\theta) + c``` used for optimization, where: ```math L_1(\theta) = \frac{1}{2}\sum_{i=1}^nlog|V_{i}| \\ @@ -51,6 +51,18 @@ L_3(\theta) = \frac{1}{2}\sum_{i=1}^n(y_i - X_{i}\beta)'V_i^{-1}(y_i - X_{i}\bet \mathcal{H}\mathcal{L}(\theta) = \mathcal{H}L_1(\theta) + \mathcal{H}L_2(\theta) + \mathcal{H} L_3(\theta) ``` +#### Weights + +If weights defined: + +```math +V_{i} = Z_{i} G Z_i'+ W^{- \frac{1}{2}}_i R_{i} W^{- \frac{1}{2}}_i +``` + + +where ```W``` - diagonal matrix of weights. + + ##### Initial step Initial (first) step before optimization may be done: diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index 05f6963a..79d5093f 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -139,6 +139,20 @@ Change θ. end θ end +# Diagonal(b) * A * Diagonal(b) - chnage only A upper triangle +@noinline function mulβdαβd!(A::AbstractMatrix, b::AbstractVector) + q = size(A, 1) + p = size(A, 2) + if !(q == p == length(b)) throw(DimensionMismatch("size(A, 1) and size(A, 2) should be equal length(b)")) end + for n in 1:p + @simd for m in 1:n + @inbounds A[m, n] *= b[m] * b[n] + end + end + A +end + + ################################################################################ @inline function tmul_unsafe(rz, θ::AbstractVector{T}) where T vec = zeros(T, size(rz, 1)) diff --git a/src/lmm.jl b/src/lmm.jl index 798f1b33..42df2c2b 100644 --- a/src/lmm.jl +++ b/src/lmm.jl @@ -11,7 +11,7 @@ struct ModelStructure end """ - LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing) + LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing, wts::Union{Nothing, AbstractVector, AbstractString, Symbol} = nothing) Make Linear-Mixed Model object. @@ -25,9 +25,11 @@ Make Linear-Mixed Model object. `repeated`: is a repeated effect (only one) +`wts`: regression weights (residuals). + See also: [`@lmmformula`](@ref) """ -struct LMM{T<:AbstractFloat} <: MetidaModel +struct LMM{T <: AbstractFloat, W <: Union{LMMWts, Nothing}} <: MetidaModel model::FormulaTerm f::FormulaTerm modstr::ModelStructure @@ -51,11 +53,11 @@ struct LMM{T<:AbstractFloat} <: MetidaModel rankx::Int, result::ModelResult, maxvcbl::Int, - wts::Union{Nothing, LMMWts}, - log::Vector{LMMLogMsg}) where T - new{T}(model, f, modstr, covstr, data, dv, nfixed, rankx, result, maxvcbl, wts, log) + wts::W, + log::Vector{LMMLogMsg}) where T where W <: Union{LMMWts, Nothing} + new{T, W}(model, f, modstr, covstr, data, dv, nfixed, rankx, result, maxvcbl, wts, log) end - function LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing, wts = nothing) + function LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing, wts::Union{Nothing, AbstractVector, AbstractString, Symbol} = nothing) #need check responce - Float if !Tables.istable(data) error("Data not a table!") end if repeated === nothing && random === nothing @@ -69,6 +71,10 @@ struct LMM{T<:AbstractFloat} <: MetidaModel if !isnothing(repeated) union!(tv, termvars(repeated)) end + if !isnothing(wts) && wts isa Union{AbstractString, Symbol} + if wts isa String wts = Symbol(wts) end + union!(tv, (wts,)) + end ct = Tables.columntable(data) if !(tv ⊆ keys(ct)) error("Some column(s) not found!") end data, data_ = StatsModels.missing_omit(NamedTuple{tuple(tv...)}(ct)) @@ -113,7 +119,11 @@ struct LMM{T<:AbstractFloat} <: MetidaModel if isnothing(wts) lmmwts = nothing else + if wts isa Symbol + wts = Tables.getcolumn(data, wts) + end if length(lmmdata.yv) == length(wts) + if any(x -> x <= zero(x), wts) error("Only cases with positive weights allowed!") end lmmwts = LMMWts(wts, covstr.vcovblock) else @warn "wts count not equal observations count! wts not used." diff --git a/src/lmmdata.jl b/src/lmmdata.jl index 7136ea6f..bf07fa7b 100644 --- a/src/lmmdata.jl +++ b/src/lmmdata.jl @@ -43,7 +43,7 @@ struct LMMWts{T<:AbstractFloat} function LMMWts(wts::Vector{T}, vcovblock) where T sqrtwts = Vector{Vector{T}}(undef, length(vcovblock)) for i in eachindex(vcovblock) - y[i] = sqrt.(view(wts, vcovblock[i])) + sqrtwts[i] = @. inv(sqrt($(view(wts, vcovblock[i])))) end LMMWts(sqrtwts) end diff --git a/src/miboot.jl b/src/miboot.jl index 04409825..e5075ed7 100644 --- a/src/miboot.jl +++ b/src/miboot.jl @@ -28,12 +28,17 @@ struct MILMM{T} <: MetidaModel mrs::MRS wts::Union{Nothing, LMMWts} log::Vector{LMMLogMsg} - function MILMM(lmm::LMM{T}, data) where T + function MILMM(lmm::LMM{T}, data; wts::Union{Nothing, AbstractVector, AbstractString, Symbol} = nothing) where T if !Tables.istable(data) error("Data not a table!") end if !isfitted(lmm) error("LMM should be fitted!") end tv = termvars(lmm.model.rhs) union!(tv, termvars(lmm.covstr.random)) union!(tv, termvars(lmm.covstr.repeated)) + if !isnothing(wts) && wts isa Union{AbstractString, Symbol} + if wts isa String wts = Symbol(wts) end + union!(tv, (wts,)) + end + datam, data_ = StatsModels.missing_omit(NamedTuple{tuple(tv...)}(Tables.columntable(data))) rv = termvars(lmm.model.lhs)[1] rcol = Tables.getcolumn(data, rv)[data_] @@ -53,8 +58,25 @@ struct MILMM{T} <: MetidaModel covstr = CovStructure(lmm.covstr.random, lmm.covstr.repeated, data) dv = LMMDataViews(lmf, lmmdata.yv, covstr.vcovblock) mb = missblocks(dv.yv) - dist = mrsdist(lmm, mb, covstr, dv.xv, dv.yv) - new{T}(lmm, lmm.f, lmm.modstr, covstr, lmmdata, dv, findmax(length, covstr.vcovblock)[1], MRS(mb, dist), lmm.wts, lmmlog) + + if isnothing(wts) + lmmwts = nothing + else + if wts isa Symbol + wts = Tables.getcolumn(data, wts) + end + if length(lmmdata.yv) == length(wts) + lmmwts = LMMWts(wts, covstr.vcovblock) + else + @warn "wts count not equal observations count! wts not used." + lmmwts = nothing + end + end + + + dist = mrsdist(lmm, mb, covstr, lmmwts, dv.xv, dv.yv) + + new{T}(lmm, lmm.f, lmm.modstr, covstr, lmmdata, dv, findmax(length, covstr.vcovblock)[1], MRS(mb, dist), lmmwts, lmmlog) end end struct MILMMResult{T} @@ -511,11 +533,11 @@ function missblocks(yv) vec end # return distribution vector for -function mrsdist(lmm, mb, covstr, xv, yv) +function mrsdist(lmm, mb, covstr, lmmwts, xv, yv) dist = Vector{FullNormal}(undef, length(mb)) #Base.Threads.@threads for i in 1:length(mb) - v = vmatrix(lmm.result.theta, covstr, mb[i][1]) + v = vmatrix(lmm.result.theta, covstr, lmmwts, mb[i][1]) rv = covmatreorder(v, mb[i][2]) dist[i] = mvconddist(rv[1], rv[2], mb[i][2], lmm.result.beta, xv[mb[i][1]], yv[mb[i][1]]) end diff --git a/src/utils.jl b/src/utils.jl index e8ff780a..7cc1d8f9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -264,6 +264,21 @@ function rmatrix(lmm::LMM{T}, i::Int) where T rmat_base_inc!(R, lmm.result.theta[lmm.covstr.tr[end]], lmm.covstr, i) Symmetric(R) end + +##################################################################### + +function applywts!(::Any, ::Int, ::Nothing) + nothing +end + +function applywts!(V::AbstractMatrix, i::Int, wts::LMMWts) + mulβdαβd!(V, wts.sqrtwts[i]) +end + + +##################################################################### + +##################################################################### """ vmatrix!(V, θ, lmm, i) @@ -272,13 +287,15 @@ Update variance-covariance matrix V for i bolock. Upper triangular updated. function vmatrix!(V, θ, lmm::LMM, i::Int) # pub API gvec = gmatvec(θ, lmm.covstr) rmat_base_inc!(V, θ[lmm.covstr.tr[end]], lmm.covstr, i) + applywts!(V, i, lmm.wts) zgz_base_inc!(V, gvec, lmm.covstr, i) end # !!! Main function REML used -function vmatrix!(V, G, rθ, lmm::LMM, i::Int) +@noinline function vmatrix!(V, G, rθ, lmm::LMM, i::Int) rmat_base_inc!(V, rθ, lmm.covstr, i) + applywts!(V, i, lmm.wts) zgz_base_inc!(V, G, lmm.covstr, i) end @@ -298,10 +315,11 @@ function vmatrix(θ::AbstractVector{T}, lmm::LMM, i::Int) where T Symmetric(V) end # For Multiple Imputation -function vmatrix(θ::Vector, covstr::CovStructure, i::Int) +function vmatrix(θ::Vector, covstr::CovStructure, lmmwts, i::Int) V = zeros(length(covstr.vcovblock[i]), length(covstr.vcovblock[i])) gvec = gmatvec(θ, covstr) rmat_base_inc!(V, θ[covstr.tr[end]], covstr, i) + applywts!(V, i, lmmwts) #type unstable zgz_base_inc!(V, gvec, covstr, i) Symmetric(V) end diff --git a/test/csv/df0.csv b/test/csv/df0.csv index 23e46183..08133e91 100644 --- a/test/csv/df0.csv +++ b/test/csv/df0.csv @@ -1,21 +1,21 @@ -subject,sequence,period,formulation,var,var2 -1,1,1,1,1.0,1.0 -1,1,2,2,1.1,2.0 -1,1,3,1,1.2,3.0 -1,1,4,2,1.3,4.0 -2,1,1,1,2.0,5.0 -2,1,2,2,2.1,6.0 -2,1,3,1,2.4,7.0 -2,1,4,2,2.2,2.0 -3,2,1,2,1.3,3.5 -3,2,2,1,1.5,3.0 -3,2,3,2,1.6,4.0 -3,2,4,1,1.4,5.0 -4,2,1,2,1.5,6.0 -4,2,2,1,1.7,1.2 -4,2,3,2,1.3,3.4 -4,2,4,1,1.4,5.0 -5,2,1,2,1.5,6.0 -5,2,2,1,1.7,7.0 -5,2,3,2,1.2,10.0 -5,2,4,1,1.8,9.0 +"subject","sequence","period","formulation","var","var2","wts" +1,1,1,1,1,1,1 +1,1,2,2,1.1,2,1 +1,1,3,1,1.2,3,1 +1,1,4,2,1.3,4,0.5 +2,1,1,1,2,5,1 +2,1,2,2,2.1,6,1 +2,1,3,1,2.4,7,0.25 +2,1,4,2,2.2,2,0.25 +3,2,1,2,1.3,3.5,1 +3,2,2,1,1.5,3,0.3 +3,2,3,2,1.6,4,0.3 +3,2,4,1,1.4,5,0.3 +4,2,1,2,1.5,6,0.1 +4,2,2,1,1.7,1.2,0.2 +4,2,3,2,1.3,3.4,0.3 +4,2,4,1,1.4,5,0.4 +5,2,1,2,1.5,6,1 +5,2,2,1,1.7,7,2 +5,2,3,2,1.2,10,3 +5,2,4,1,1.8,9,4 diff --git a/test/test.jl b/test/test.jl index 2d931e7a..7aaf3134 100644 --- a/test/test.jl +++ b/test/test.jl @@ -191,6 +191,23 @@ include("testdata.jl") random = 1+var^2|subject:Metida.SI), df0) Metida.fit!(lmmint) @test Metida.m2logreml(lmmint) ≈ 84.23373276096902 atol=1E-6 + + # Wts + + df0.wtsc = fill(0.5, size(df0, 1)) + lmm = Metida.LMM(@formula(var~sequence+period+formulation), df0; + random = Metida.VarEffect(Metida.@covstr(formulation|subject), Metida.DIAG), + wts = df0.wtsc) + fit!(lmm) + @test Metida.m2logreml(lmm) ≈ 16.241112644506067 atol=1E-6 + + lmm = Metida.LMM(@formula(var~sequence+period+formulation), df0; + random = Metida.VarEffect(Metida.@covstr(formulation|subject), Metida.DIAG), + wts = "wts") + fit!(lmm) + @test Metida.m2logreml(lmm) ≈ 17.823729 atol=1E-6 # TEST WITH SPSS 28 + + end ################################################################################ # df0 diff --git a/test/testdata.jl b/test/testdata.jl index 47675174..9532ab19 100644 --- a/test/testdata.jl +++ b/test/testdata.jl @@ -1,7 +1,7 @@ # Metida #Simple dataset -df0 = CSV.File(path*"/csv/df0.csv"; types = [String, String, String, String, Float64, Float64]) |> DataFrame +df0 = CSV.File(path*"/csv/df0.csv"; types = [String, String, String, String, Float64, Float64, Float64]) |> DataFrame df0m = CSV.File(path*"/csv/df0miss.csv"; types = [String, String, String, String, Float64, Float64]) |> DataFrame