Skip to content

Commit

Permalink
wts first concept
Browse files Browse the repository at this point in the history
  • Loading branch information
PharmCat committed Jan 1, 2024
1 parent 4a509fb commit 8f7c5e5
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 37 deletions.
14 changes: 13 additions & 1 deletion docs/src/details.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ logREML(\theta,\beta) = -\frac{N-p}{2} - \frac{1}{2}\sum_{i=1}^nlog|V_{\theta, i
-\frac{1}{2}log|\sum_{i=1}^nX_i'V_{\theta, i}^{-1}X_i|-\frac{1}{2}\sum_{i=1}^n(y_i - X_{i}\beta)'V_{\theta, i}^{-1}(y_i - X_{i}\beta)
```

Actually ```L(\theta) = -2logREML = L_1(\theta) + L_2(\theta) + L_3(\theta) + c`` used for optimization, where:
Actually ```L(\theta) = -2logREML = L_1(\theta) + L_2(\theta) + L_3(\theta) + c``` used for optimization, where:

```math
L_1(\theta) = \frac{1}{2}\sum_{i=1}^nlog|V_{i}| \\
Expand All @@ -51,6 +51,18 @@ L_3(\theta) = \frac{1}{2}\sum_{i=1}^n(y_i - X_{i}\beta)'V_i^{-1}(y_i - X_{i}\bet
\mathcal{H}\mathcal{L}(\theta) = \mathcal{H}L_1(\theta) + \mathcal{H}L_2(\theta) + \mathcal{H} L_3(\theta)
```

#### Weights

If weights are defined:

```math
V_{i} = Z_{i} G Z_i'+ W^{- \frac{1}{2}}_i R_{i} W^{- \frac{1}{2}}_i
```


where ```W_i``` is the diagonal matrix of weights for block ```i```.


##### Initial step

Initial (first) step before optimization may be done:
Expand Down
14 changes: 14 additions & 0 deletions src/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,20 @@ Change θ.
end
θ
end
# Compute Diagonal(b) * A * Diagonal(b) in place, changing only the upper
# triangle of A (diagonal included): A[m, n] is multiplied by b[m] * b[n]
# for every m <= n. Entries below the diagonal are not touched.
#
# Throws `DimensionMismatch` unless A is square with side `length(b)`, and
# `ArgumentError` if A or b does not use 1-based indexing.
# Returns the mutated A.
@noinline function mulβdαβd!(A::AbstractMatrix, b::AbstractVector)
    # The loops below start at 1 and run under @inbounds, so containers with
    # offset axes have to be rejected up front.
    Base.require_one_based_indexing(A, b)
    q = size(A, 1)
    p = size(A, 2)
    if !(q == p == length(b))
        throw(DimensionMismatch("size(A, 1) and size(A, 2) should be equal length(b)"))
    end
    for n in 1:p
        bn = b[n]
        # Column-wise walk over the upper triangle: rows 1..n of column n.
        @simd for m in 1:n
            @inbounds A[m, n] *= b[m] * bn
        end
    end
    return A
end


################################################################################
@inline function tmul_unsafe(rz, θ::AbstractVector{T}) where T
vec = zeros(T, size(rz, 1))
Expand Down
22 changes: 16 additions & 6 deletions src/lmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct ModelStructure
end

"""
LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing)
LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing, wts::Union{Nothing, AbstractVector, AbstractString, Symbol} = nothing)
Make Linear-Mixed Model object.
Expand All @@ -25,9 +25,11 @@ Make Linear-Mixed Model object.
`repeated`: is a repeated effect (only one)
`wts`: regression weights (residuals).
See also: [`@lmmformula`](@ref)
"""
struct LMM{T<:AbstractFloat} <: MetidaModel
struct LMM{T <: AbstractFloat, W <: Union{LMMWts, Nothing}} <: MetidaModel
model::FormulaTerm
f::FormulaTerm
modstr::ModelStructure
Expand All @@ -51,11 +53,11 @@ struct LMM{T<:AbstractFloat} <: MetidaModel
rankx::Int,
result::ModelResult,
maxvcbl::Int,
wts::Union{Nothing, LMMWts},
log::Vector{LMMLogMsg}) where T
new{T}(model, f, modstr, covstr, data, dv, nfixed, rankx, result, maxvcbl, wts, log)
wts::W,
log::Vector{LMMLogMsg}) where T where W <: Union{LMMWts, Nothing}
new{T, W}(model, f, modstr, covstr, data, dv, nfixed, rankx, result, maxvcbl, wts, log)
end
function LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing, wts = nothing)
function LMM(model, data; contrasts=Dict{Symbol,Any}(), random::Union{Nothing, VarEffect, Vector{VarEffect}} = nothing, repeated::Union{Nothing, VarEffect} = nothing, wts::Union{Nothing, AbstractVector, AbstractString, Symbol} = nothing)
# TODO: need to check that the response is Float
if !Tables.istable(data) error("Data not a table!") end
if repeated === nothing && random === nothing
Expand All @@ -69,6 +71,10 @@ struct LMM{T<:AbstractFloat} <: MetidaModel
if !isnothing(repeated)
union!(tv, termvars(repeated))
end
if !isnothing(wts) && wts isa Union{AbstractString, Symbol}
if wts isa String wts = Symbol(wts) end
union!(tv, (wts,))
end
ct = Tables.columntable(data)
if !(tv keys(ct)) error("Some column(s) not found!") end
data, data_ = StatsModels.missing_omit(NamedTuple{tuple(tv...)}(ct))
Expand Down Expand Up @@ -113,7 +119,11 @@ struct LMM{T<:AbstractFloat} <: MetidaModel
if isnothing(wts)
lmmwts = nothing
else
if wts isa Symbol
wts = Tables.getcolumn(data, wts)
end
if length(lmmdata.yv) == length(wts)
if any(x -> x <= zero(x), wts) error("Only cases with positive weights allowed!") end
lmmwts = LMMWts(wts, covstr.vcovblock)
else
@warn "wts count not equal observations count! wts not used."
Expand Down
2 changes: 1 addition & 1 deletion src/lmmdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct LMMWts{T<:AbstractFloat}
# Build the per-block weight factors from a flat vector of observation weights.
# For each index set in `vcovblock`, the stored vector holds 1 / sqrt(w) for the
# weights selected by that index set.
# Any real-valued vector is accepted; non-float element types are promoted with
# `float` so that the reciprocal square roots can be stored.
# Weights are not validated here.
function LMMWts(wts::AbstractVector{T}, vcovblock) where T <: Real
    F = float(T)
    sqrtwts = Vector{Vector{F}}(undef, length(vcovblock))
    for i in eachindex(vcovblock)
        # `$` keeps `view` itself out of the broadcast; only inv and sqrt are dotted.
        sqrtwts[i] = @. inv(sqrt($(view(wts, vcovblock[i]))))
    end
    LMMWts(sqrtwts)
end
Expand Down
32 changes: 27 additions & 5 deletions src/miboot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,17 @@ struct MILMM{T} <: MetidaModel
mrs::MRS
wts::Union{Nothing, LMMWts}
log::Vector{LMMLogMsg}
function MILMM(lmm::LMM{T}, data) where T
function MILMM(lmm::LMM{T}, data; wts::Union{Nothing, AbstractVector, AbstractString, Symbol} = nothing) where T
if !Tables.istable(data) error("Data not a table!") end
if !isfitted(lmm) error("LMM should be fitted!") end
tv = termvars(lmm.model.rhs)
union!(tv, termvars(lmm.covstr.random))
union!(tv, termvars(lmm.covstr.repeated))
if !isnothing(wts) && wts isa Union{AbstractString, Symbol}
if wts isa String wts = Symbol(wts) end
union!(tv, (wts,))

Check warning on line 39 in src/miboot.jl

View check run for this annotation

Codecov / codecov/patch

src/miboot.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
end

datam, data_ = StatsModels.missing_omit(NamedTuple{tuple(tv...)}(Tables.columntable(data)))
rv = termvars(lmm.model.lhs)[1]
rcol = Tables.getcolumn(data, rv)[data_]
Expand All @@ -53,8 +58,25 @@ struct MILMM{T} <: MetidaModel
covstr = CovStructure(lmm.covstr.random, lmm.covstr.repeated, data)
dv = LMMDataViews(lmf, lmmdata.yv, covstr.vcovblock)
mb = missblocks(dv.yv)
dist = mrsdist(lmm, mb, covstr, dv.xv, dv.yv)
new{T}(lmm, lmm.f, lmm.modstr, covstr, lmmdata, dv, findmax(length, covstr.vcovblock)[1], MRS(mb, dist), lmm.wts, lmmlog)

if isnothing(wts)
lmmwts = nothing
else
if wts isa Symbol
wts = Tables.getcolumn(data, wts)

Check warning on line 66 in src/miboot.jl

View check run for this annotation

Codecov / codecov/patch

src/miboot.jl#L65-L66

Added lines #L65 - L66 were not covered by tests
end
if length(lmmdata.yv) == length(wts)
lmmwts = LMMWts(wts, covstr.vcovblock)

Check warning on line 69 in src/miboot.jl

View check run for this annotation

Codecov / codecov/patch

src/miboot.jl#L68-L69

Added lines #L68 - L69 were not covered by tests
else
@warn "wts count not equal observations count! wts not used."
lmmwts = nothing

Check warning on line 72 in src/miboot.jl

View check run for this annotation

Codecov / codecov/patch

src/miboot.jl#L71-L72

Added lines #L71 - L72 were not covered by tests
end
end


dist = mrsdist(lmm, mb, covstr, lmmwts, dv.xv, dv.yv)

new{T}(lmm, lmm.f, lmm.modstr, covstr, lmmdata, dv, findmax(length, covstr.vcovblock)[1], MRS(mb, dist), lmmwts, lmmlog)
end
end
struct MILMMResult{T}
Expand Down Expand Up @@ -511,11 +533,11 @@ function missblocks(yv)
vec
end
# return distribution vector for
function mrsdist(lmm, mb, covstr, xv, yv)
function mrsdist(lmm, mb, covstr, lmmwts, xv, yv)
dist = Vector{FullNormal}(undef, length(mb))
#Base.Threads.@threads
for i in 1:length(mb)
v = vmatrix(lmm.result.theta, covstr, mb[i][1])
v = vmatrix(lmm.result.theta, covstr, lmmwts, mb[i][1])
rv = covmatreorder(v, mb[i][2])
dist[i] = mvconddist(rv[1], rv[2], mb[i][2], lmm.result.beta, xv[mb[i][1]], yv[mb[i][1]])
end
Expand Down
22 changes: 20 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,21 @@ function rmatrix(lmm::LMM{T}, i::Int) where T
rmat_base_inc!(R, lmm.result.theta[lmm.covstr.tr[end]], lmm.covstr, i)
Symmetric(R)
end

#####################################################################

# No-op method: with `nothing` in the weights position the first argument is
# left untouched and `nothing` is returned.
applywts!(::Any, ::Int, ::Nothing) = nothing

# Apply the stored weight factors of block `i` to `V` in place by delegating to
# `mulβdαβd!`; whatever that call returns is passed back to the caller.
function applywts!(V::AbstractMatrix, i::Int, wts::LMMWts)
    factors = wts.sqrtwts[i]
    return mulβdαβd!(V, factors)
end


#####################################################################

#####################################################################
"""
vmatrix!(V, θ, lmm, i)
Expand All @@ -272,13 +287,15 @@ Update variance-covariance matrix V for i bolock. Upper triangular updated.
function vmatrix!(V, θ, lmm::LMM, i::Int) # pub API
    covstr = lmm.covstr
    gvec = gmatvec(θ, covstr)
    # Order matters: the weights are applied after the rmat step and before
    # the zgz step, so only what rmat_base_inc! wrote into V gets scaled.
    rmat_base_inc!(V, θ[covstr.tr[end]], covstr, i)
    applywts!(V, i, lmm.wts)
    return zgz_base_inc!(V, gvec, covstr, i)
end

# !!! Main function REML used
function vmatrix!(V, G, rθ, lmm::LMM, i::Int)
@noinline function vmatrix!(V, G, rθ, lmm::LMM, i::Int)
    covstr = lmm.covstr
    # Same three-step sequence as the public method, but G and rθ are supplied
    # by the caller instead of being derived from θ here.
    rmat_base_inc!(V, rθ, covstr, i)
    applywts!(V, i, lmm.wts)
    return zgz_base_inc!(V, G, covstr, i)
end

Expand All @@ -298,10 +315,11 @@ function vmatrix(θ::AbstractVector{T}, lmm::LMM, i::Int) where T
Symmetric(V)
end
# For Multiple Imputation
function vmatrix::Vector, covstr::CovStructure, i::Int)
# Variant that works from a covariance structure and an explicit weights object
# (which may be `nothing`) instead of an LMM. Allocates a fresh zero matrix sized
# by block `i`, fills it in three steps (rmat, weights, zgz) and returns it
# wrapped in `Symmetric`.
function vmatrix(θ::Vector, covstr::CovStructure, lmmwts, i::Int)
    n = length(covstr.vcovblock[i])
    V = zeros(n, n)
    gvec = gmatvec(θ, covstr)
    rmat_base_inc!(V, θ[covstr.tr[end]], covstr, i)
    applywts!(V, i, lmmwts) #type unstable
    zgz_base_inc!(V, gvec, covstr, i)
    return Symmetric(V)
end
Expand Down
42 changes: 21 additions & 21 deletions test/csv/df0.csv
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
subject,sequence,period,formulation,var,var2
1,1,1,1,1.0,1.0
1,1,2,2,1.1,2.0
1,1,3,1,1.2,3.0
1,1,4,2,1.3,4.0
2,1,1,1,2.0,5.0
2,1,2,2,2.1,6.0
2,1,3,1,2.4,7.0
2,1,4,2,2.2,2.0
3,2,1,2,1.3,3.5
3,2,2,1,1.5,3.0
3,2,3,2,1.6,4.0
3,2,4,1,1.4,5.0
4,2,1,2,1.5,6.0
4,2,2,1,1.7,1.2
4,2,3,2,1.3,3.4
4,2,4,1,1.4,5.0
5,2,1,2,1.5,6.0
5,2,2,1,1.7,7.0
5,2,3,2,1.2,10.0
5,2,4,1,1.8,9.0
"subject","sequence","period","formulation","var","var2","wts"
1,1,1,1,1,1,1
1,1,2,2,1.1,2,1
1,1,3,1,1.2,3,1
1,1,4,2,1.3,4,0.5
2,1,1,1,2,5,1
2,1,2,2,2.1,6,1
2,1,3,1,2.4,7,0.25
2,1,4,2,2.2,2,0.25
3,2,1,2,1.3,3.5,1
3,2,2,1,1.5,3,0.3
3,2,3,2,1.6,4,0.3
3,2,4,1,1.4,5,0.3
4,2,1,2,1.5,6,0.1
4,2,2,1,1.7,1.2,0.2
4,2,3,2,1.3,3.4,0.3
4,2,4,1,1.4,5,0.4
5,2,1,2,1.5,6,1
5,2,2,1,1.7,7,2
5,2,3,2,1.2,10,3
5,2,4,1,1.8,9,4
17 changes: 17 additions & 0 deletions test/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,23 @@ include("testdata.jl")
random = 1+var^2|subject:Metida.SI), df0)
Metida.fit!(lmmint)
@test Metida.m2logreml(lmmint) ≈ 84.23373276096902 atol=1E-6

# Wts

# Weights passed directly as a vector (constant 0.5 for every observation).
df0.wtsc = fill(0.5, size(df0, 1))
lmm = Metida.LMM(@formula(var~sequence+period+formulation), df0;
    random = Metida.VarEffect(Metida.@covstr(formulation|subject), Metida.DIAG),
    wts = df0.wtsc)
fit!(lmm)
@test Metida.m2logreml(lmm) ≈ 16.241112644506067 atol=1E-6

# Weights passed as a column name.
lmm = Metida.LMM(@formula(var~sequence+period+formulation), df0;
    random = Metida.VarEffect(Metida.@covstr(formulation|subject), Metida.DIAG),
    wts = "wts")
fit!(lmm)
@test Metida.m2logreml(lmm) ≈ 17.823729 atol=1E-6 # TEST WITH SPSS 28


end
################################################################################
# df0
Expand Down
2 changes: 1 addition & 1 deletion test/testdata.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Metida

#Simple dataset
df0 = CSV.File(path*"/csv/df0.csv"; types = [String, String, String, String, Float64, Float64]) |> DataFrame
df0 = CSV.File(path*"/csv/df0.csv"; types = [String, String, String, String, Float64, Float64, Float64]) |> DataFrame

df0m = CSV.File(path*"/csv/df0miss.csv"; types = [String, String, String, String, Float64, Float64]) |> DataFrame

Expand Down

0 comments on commit 8f7c5e5

Please sign in to comment.