Skip to content

Commit

Permalink
correct broadcast! call not working in julia-1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed May 12, 2023
1 parent 70ef5e7 commit 5bb7c3b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/RobustModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,9 @@ abstract type RobustResp{T} <: ModResp end

abstract type AbstractRegularizedPred{T} end

Base.broadcastable(m::T) where {T<:AbstractEstimator} = Ref(m)
Base.broadcastable(m::T) where {T<:LossFunction} = Ref(m)
Base.broadcastable(m::AbstractEstimator) = Ref(m)
Base.broadcastable(m::LossFunction) = Ref(m)
Base.broadcastable(m::PenaltyFunction) = Ref(m)

include("tools.jl")
include("losses.jl")
Expand Down
13 changes: 11 additions & 2 deletions src/penalties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function proximal!(p::SquaredL2Penalty{T}, out, x::AbstractVector{T}, step::T=on
# return broadcast!(/, out, x, 1 + p.λ * step)
a = 1 / (1 + p.λ * step)
@inbounds @simd for i in eachindex(out, x)
out[i] = x[i] > 0 ? x[i] * a : zero(T)
out[i] = (p.nonnegative && x[i] <= 0) ? zero(T) : x[i] * a
end
return out
end
Expand All @@ -122,7 +122,16 @@ struct EuclideanPenalty{T<:AbstractFloat} <: PenaltyFunction{T}
end
cost(p::EuclideanPenalty{T}, x::AbstractVector{T}) where {T<:AbstractFloat} = p.λ * norm(x, 2)
function proximal!(p::EuclideanPenalty{T}, out, x::AbstractVector{T}, step::T=one(T)) where {T<:AbstractFloat}
nn = p.nonnegative ? norm(broadcast!(max, out, x, 0), 2) : norm(x, 2)
# nn = p.nonnegative ? norm(broadcast!(max, out, x, 0), 2) : norm(x, 2)
if p.nonnegative
nn = zero(T)
@inbounds @simd for xi in x
nn += xi > 0 ? xi^2 : zero(T)
end
nn = sqrt(nn)
else
nn = norm(x, 2)
end
return rmul!(copyto!(out, x), (1 - p.λ * step / max(p.λ * step, nn)))
end

Expand Down

0 comments on commit 5bb7c3b

Please sign in to comment.