diff --git a/src/RobustModels.jl b/src/RobustModels.jl index ac8e22b..d556390 100644 --- a/src/RobustModels.jl +++ b/src/RobustModels.jl @@ -191,8 +191,9 @@ abstract type RobustResp{T} <: ModResp end abstract type AbstractRegularizedPred{T} end -Base.broadcastable(m::T) where {T<:AbstractEstimator} = Ref(m) -Base.broadcastable(m::T) where {T<:LossFunction} = Ref(m) +Base.broadcastable(m::AbstractEstimator) = Ref(m) +Base.broadcastable(m::LossFunction) = Ref(m) +Base.broadcastable(m::PenaltyFunction) = Ref(m) include("tools.jl") include("losses.jl") diff --git a/src/penalties.jl b/src/penalties.jl index d7ba1fc..b9560c0 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -97,7 +97,7 @@ function proximal!(p::SquaredL2Penalty{T}, out, x::AbstractVector{T}, step::T=on # return broadcast!(/, out, x, 1 + p.λ * step) a = 1 / (1 + p.λ * step) @inbounds @simd for i in eachindex(out, x) - out[i] = x[i] > 0 ? x[i] * a : zero(T) + out[i] = (p.nonnegative && x[i] <= 0) ? zero(T) : x[i] * a end return out end @@ -122,7 +122,16 @@ struct EuclideanPenalty{T<:AbstractFloat} <: PenaltyFunction{T} end cost(p::EuclideanPenalty{T}, x::AbstractVector{T}) where {T<:AbstractFloat} = p.λ * norm(x, 2) function proximal!(p::EuclideanPenalty{T}, out, x::AbstractVector{T}, step::T=one(T)) where {T<:AbstractFloat} - nn = p.nonnegative ? norm(broadcast!(max, out, x, 0), 2) : norm(x, 2) +# nn = p.nonnegative ? norm(broadcast!(max, out, x, 0), 2) : norm(x, 2) + if p.nonnegative + nn = zero(T) + @inbounds @simd for xi in x + nn += xi > 0 ? xi^2 : zero(T) + end + nn = sqrt(nn) + else + nn = norm(x, 2) + end return rmul!(copyto!(out, x), (1 - p.λ * step / max(p.λ * step, nn))) end